Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-13 04:45:19 +00:00)

Merge branch 'open-webui:main' into main

Commit 415114504f: 478 changed files with 64194 additions and 22041 deletions

.github/ISSUE_TEMPLATE/bug_report.md (vendored, deleted, 80 lines)

@@ -1,80 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---

# Bug Report

## Important Notes

- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.

- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.

- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.

- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!

Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!

---

## Installation Method

[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]

## Environment

- **Open WebUI Version:** [e.g., v0.3.11]
- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]

- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]

**Confirmation:**

- [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I am on the latest version of both Open WebUI and Ollama.
- [ ] I have included the browser console logs.
- [ ] I have included the Docker container logs.
- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.

## Expected Behavior:

[Describe what you expected to happen.]

## Actual Behavior:

[Describe what actually happened.]

## Description

**Bug Summary:**
[Provide a brief but clear summary of the bug]

## Reproduction Details

**Steps to Reproduce:**
[Outline the steps to reproduce the bug. Be as detailed as possible.]

## Logs and Screenshots

**Browser Console Logs:**
[Include relevant browser console logs, if applicable]

**Docker Container Logs:**
[Include relevant Docker container logs, if applicable]

**Screenshots/Screen Recordings (if applicable):**
[Attach any relevant screenshots to help illustrate the issue]

## Additional Information

[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]

## Note

If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!

.github/ISSUE_TEMPLATE/bug_report.yaml (vendored, new file, 167 lines)

@@ -0,0 +1,167 @@
name: Bug Report
description: Create a detailed bug report to help us improve Open WebUI.
title: 'issue: '
labels: ['bug', 'triage']
assignees: []
body:
  - type: markdown
    attributes:
      value: |
        # Bug Report

        ## Important Notes

        - **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.

        - **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.

        - **Contributing**: If you encounter an issue, consider submitting a pull request or forking the project. We prioritize preventing contributor burnout to maintain Open WebUI's quality.

        - **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!

  - type: checkboxes
    id: issue-check
    attributes:
      label: Check Existing Issues
      description: Confirm that you’ve checked for existing reports before submitting a new one.
      options:
        - label: I have searched the existing issues and discussions.
          required: true
        - label: I am using the latest version of Open WebUI.
          required: true

  - type: dropdown
    id: installation-method
    attributes:
      label: Installation Method
      description: How did you install Open WebUI?
      options:
        - Git Clone
        - Pip Install
        - Docker
        - Other
    validations:
      required: true

  - type: input
    id: open-webui-version
    attributes:
      label: Open WebUI Version
      description: Specify the version (e.g., v0.3.11)
    validations:
      required: true

  - type: input
    id: ollama-version
    attributes:
      label: Ollama Version (if applicable)
      description: Specify the version (e.g., v0.2.0, or v0.1.32-rc1)
    validations:
      required: false

  - type: input
    id: operating-system
    attributes:
      label: Operating System
      description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
    validations:
      required: true

  - type: input
    id: browser
    attributes:
      label: Browser (if applicable)
      description: Specify the browser/version (e.g., Chrome 100.0, Firefox 98.0)
    validations:
      required: false

  - type: checkboxes
    id: confirmation
    attributes:
      label: Confirmation
      description: Ensure the following prerequisites have been met.
      options:
        - label: I have read and followed all instructions in `README.md`.
          required: true
        - label: I am using the latest version of **both** Open WebUI and Ollama.
          required: true
        - label: I have included the browser console logs.
          required: true
        - label: I have included the Docker container logs.
          required: true
        - label: I have **provided every relevant configuration, setting, and environment variable used in my setup.**
          required: true
        - label: I have clearly **listed every relevant configuration, custom setting, environment variable, and command-line option that influences my setup** (such as Docker Compose overrides, .env values, browser settings, authentication configurations, etc).
          required: true
        - label: |
            I have documented **step-by-step reproduction instructions that are precise, sequential, and leave nothing to interpretation**. My steps:
            - Start with the initial platform/version/OS and dependencies used,
            - Specify exact install/launch/configure commands,
            - List URLs visited, user input (incl. example values/emails/passwords if needed),
            - Describe all options and toggles enabled or changed,
            - Include any files or environmental changes,
            - Identify the expected and actual result at each stage,
            - Ensure any reasonably skilled user can follow and hit the same issue.
          required: true

  - type: textarea
    id: expected-behavior
    attributes:
      label: Expected Behavior
      description: Describe what should have happened.
    validations:
      required: true

  - type: textarea
    id: actual-behavior
    attributes:
      label: Actual Behavior
      description: Describe what actually happened.
    validations:
      required: true

  - type: textarea
    id: reproduction-steps
    attributes:
      label: Steps to Reproduce
      description: |
        Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables.
        **If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.**
      placeholder: |
        Example (include every detail):
        1. Start with a clean Ubuntu 22.04 install.
        2. Install Docker v24.0.5 and start the service.
        3. Clone the Open WebUI repo (git clone ...).
        4. Use the Docker Compose file without modifications.
        5. Open browser Chrome 115.0 in incognito mode.
        6. Go to http://localhost:8080 and log in with user "test@example.com".
        7. Set the language to "English" and theme to "Dark".
        8. Attempt to connect to Ollama at "http://localhost:11434".
        9. Observe that the error message "Connection refused" appears at the top right.

        Please list each step carefully and include all relevant configuration, settings, and options.
    validations:
      required: true
  - type: textarea
    id: logs-screenshots
    attributes:
      label: Logs & Screenshots
      description: Include relevant logs, errors, or screenshots to help diagnose the issue.
      placeholder: 'Attach logs from the browser console, Docker logs, or error messages.'
    validations:
      required: true

  - type: textarea
    id: additional-info
    attributes:
      label: Additional Information
      description: Provide any extra details that may assist in understanding the issue.
    validations:
      required: false

  - type: markdown
    attributes:
      value: |
        ## Note
        If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
        Thank you for contributing to Open WebUI!

.github/ISSUE_TEMPLATE/config.yml (vendored, new file, 1 line)

@@ -0,0 +1 @@
blank_issues_enabled: false

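Setting `blank_issues_enabled: false` means contributors must use one of the issue forms above rather than opening a free-form issue. The same `config.yml` can also route people to other destinations via `contact_links`; a minimal sketch of that schema follows (the link entry is illustrative only and not part of this commit):

```yaml
blank_issues_enabled: false
# Hypothetical contact link, shown only to illustrate the config.yml schema.
contact_links:
  - name: Questions & Ideas
    url: https://github.com/open-webui/open-webui/discussions
    about: Ask questions or propose ideas in Discussions instead of filing an issue.
```
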
.github/ISSUE_TEMPLATE/feature_request.md (vendored, deleted, 35 lines)

@@ -1,35 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---

# Feature Request

## Important Notes

- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.

- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.

- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.

- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!

Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.

.github/ISSUE_TEMPLATE/feature_request.yaml (vendored, new file, 64 lines)

@@ -0,0 +1,64 @@
name: Feature Request
description: Suggest an idea for this project
title: 'feat: '
labels: ['triage']
body:
  - type: markdown
    attributes:
      value: |
        ## Important Notes
        ### Before submitting
        Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
        It's likely we're already tracking it! If you’re unsure, start a discussion post first.
        This will help us efficiently focus on improving the project.

        ### Collaborate respectfully
        We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.

        Remember:
        - Open WebUI is a **volunteer-driven project**
        - It's managed by a **single maintainer**
        - It's supported by contributors who also have **full-time jobs**

        We appreciate your time and ask that you **respect ours**.

        ### Contributing
        If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.

        ### Bug reproducibility
        If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a `pip install` with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "[issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!

  - type: checkboxes
    id: existing-issue
    attributes:
      label: Check Existing Issues
      description: Please confirm that you've checked for existing similar requests
      options:
        - label: I have searched the existing issues and discussions.
          required: true
  - type: textarea
    id: problem-description
    attributes:
      label: Problem Description
      description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
      placeholder: "Ex. I'm always frustrated when..."
    validations:
      required: true
  - type: textarea
    id: solution-description
    attributes:
      label: Desired Solution
      description: Clearly describe what you want to happen.
    validations:
      required: true
  - type: textarea
    id: alternatives-considered
    attributes:
      label: Alternatives Considered
      description: A clear and concise description of any alternative solutions or features you've considered.
  - type: textarea
    id: additional-context
    attributes:
      label: Additional Context
      description: Add any other context or screenshots about the feature request here.

.github/dependabot.yml (vendored, 14 lines changed)

@@ -1,12 +1,26 @@
version: 2
updates:
  - package-ecosystem: uv
    directory: '/'
    schedule:
      interval: monthly
    target-branch: 'dev'

  - package-ecosystem: pip
    directory: '/backend'
    schedule:
      interval: monthly
    target-branch: 'dev'

  - package-ecosystem: npm
    directory: '/'
    schedule:
      interval: monthly
    target-branch: 'dev'

  - package-ecosystem: 'github-actions'
    directory: '/'
    schedule:
      # Check for updates to GitHub Actions every week
      interval: monthly
    target-branch: 'dev'

.github/pull_request_template.md (vendored, 10 lines changed)

@@ -9,9 +9,9 @@
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
-- [ ] **Testing:** Have you written and run sufficient tests for validating the changes?
+- [ ] **Testing:** Have you written and run sufficient tests to validate the changes?
- [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
-- [ ] **Prefix:** To cleary categorize this pull request, prefix the pull request title, using one of the following:
+- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
  - **BREAKING CHANGE**: Significant changes that may affect compatibility
  - **build**: Changes that affect the build system or external dependencies
  - **ci**: Changes to our continuous integration processes or workflows

@@ -22,7 +22,7 @@
  - **i18n**: Internationalization or localization changes
  - **perf**: Performance improvement
  - **refactor**: Code restructuring for better maintainability, readability, or scalability
-  - **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
+  - **style**: Changes that do not affect the meaning of the code (white space, formatting, missing semi-colons, etc.)
  - **test**: Adding missing tests or correcting existing tests
  - **WIP**: Work in progress, a temporary label for incomplete or ongoing work

@@ -70,3 +70,7 @@
### Screenshots or Videos

- [Attach any relevant screenshots or videos demonstrating the changes]
+
+### Contributor License Agreement
+
+By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.

.github/workflows/docker-build.yaml (vendored, 6 lines changed)

@@ -14,7 +14,7 @@ env:

jobs:
  build-main-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    permissions:
      contents: read
      packages: write

@@ -111,7 +111,7 @@ jobs:
          retention-days: 1

  build-cuda-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    permissions:
      contents: read
      packages: write

@@ -211,7 +211,7 @@ jobs:
          retention-days: 1

  build-ollama-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    permissions:
      contents: read
      packages: write

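All three jobs above swap a fixed `ubuntu-latest` runner for a per-platform choice. GitHub Actions expressions have no ternary operator, so the diff uses the common `&&`/`||` idiom: arm64 matrix entries get a native ARM runner and everything else falls back to `ubuntu-latest`. A trimmed, hypothetical workflow showing just that pattern (job and step names are illustrative, not the repository's actual workflow):

```yaml
name: multiarch-build-example
on: [push]

jobs:
  build:
    strategy:
      matrix:
        platform: [linux/amd64, linux/arm64]
    # 'cond && a || b' emulates a ternary: arm64 entries run on a native ARM
    # runner, everything else on ubuntu-latest. Caveat: the idiom breaks if
    # the "true" branch is itself a falsy value such as an empty string.
    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    steps:
      - uses: actions/checkout@v4
      - run: echo "Building ${{ matrix.platform }} on $(uname -m)"
```

Building arm64 images on a native ARM runner avoids QEMU emulation, which is typically the motivation for this change.
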
.github/workflows/format-backend.yaml (vendored, 14 lines changed)

@@ -5,10 +5,18 @@ on:
    branches:
      - main
      - dev
+    paths:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
  pull_request:
    branches:
      - main
      - dev
+    paths:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'

jobs:
  build:

@@ -17,7 +25,9 @@ jobs:

    strategy:
      matrix:
-        python-version: [3.11]
+        python-version:
+          - 3.11.x
+          - 3.12.x

    steps:
      - uses: actions/checkout@v4

@@ -25,7 +35,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-          python-version: ${{ matrix.python-version }}
+          python-version: '${{ matrix.python-version }}'

      - name: Install dependencies
        run: |

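Two details in this diff are worth noting: the matrix now uses `.x` version specs, and the `setup-python` input is quoted. Both guard against YAML's number parsing, which reads a bare `3.10` as the float `3.1`. A small, hypothetical matrix illustrating the pitfall (not taken from this repository):

```yaml
strategy:
  matrix:
    python-version:
      # - 3.10    # risky: YAML parses this as the float 3.1
      - '3.10'    # safe: quoting keeps it the string "3.10"
      - 3.11.x    # safe: not a valid number, so it stays a string
```
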
.github/workflows/format-build-frontend.yaml (vendored, 10 lines changed)

@@ -5,10 +5,18 @@ on:
    branches:
      - main
      - dev
+    paths-ignore:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
  pull_request:
    branches:
      - main
      - dev
+    paths-ignore:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'

jobs:
  build:

@@ -21,7 +29,7 @@ jobs:
      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
-          node-version: '22' # Or specify any other version you want to use
+          node-version: '22'

      - name: Install Dependencies
        run: npm install

.github/workflows/release-pypi.yml (vendored, 6 lines changed)

@@ -17,9 +17,13 @@ jobs:
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Install Git
+        run: sudo apt-get update && sudo apt-get install -y git
      - uses: actions/setup-node@v4
        with:
-          node-version: 18
+          node-version: 22
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11

CHANGELOG.md (491 lines changed)

@@ -5,6 +5,497 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.6.13] - 2025-05-30

### Added

- 🟦 **Azure OpenAI Embedding Support**: You can now select Azure OpenAI endpoints for text embeddings, unlocking seamless integration with enterprise-scale Azure AI for powerful RAG and knowledge workflows—no more workarounds, connect and scale effortlessly.
- 🧩 **Smarter Custom Parameter Handling**: Instantly enjoy more flexible model setup—any JSON pasted into custom parameter fields is now parsed automatically, so you can define rich, nested parameters without tedious manual adjustment. This streamlines advanced configuration for all models and accelerates experimentation.
- ⚙️ **General Backend Refactoring**: Significant backend improvements deliver a cleaner codebase for better maintainability, faster performance, and even greater platform reliability—making all your workflows run more smoothly.
- 🌏 **Localization Upgrades**: Experience highly improved user interface translations and clarity in Simplified, Traditional Chinese, Korean, and Finnish, offering a more natural, accurate, and accessible experience for global users.

### Fixed

- 🛡️ **Robust Message Handling on Chat Load**: Fixed an issue where chat pages could fail to load if a referenced message was missing or undefined; now, chats always load smoothly and missing IDs no longer disrupt your workflow.
- 📝 **Correct Prompt Access Control**: Ensured that the prompt access controls register properly, restoring reliable permissioning and safeguarding your prompt workflows.
- 🛠 **Open WebUI-Specific Params No Longer Sent to Models**: Fixed a bug that sent internal WebUI parameters to APIs, ensuring only intended model options are transmitted—restoring predictable, error-free model operation.
- 🧠 **Refined Memory Error Handling**: Enhanced stability during memory-related operations, so even uncommon memory errors are gracefully managed without disrupting your session—resulting in a more reliable, worry-free experience.

## [0.6.12] - 2025-05-29

### Added

- 🧩 **Custom Advanced Model Parameters**: You can now add your own tailor-made advanced parameters to any model, empowering you to fine-tune behavior and unlock greater flexibility beyond just the built-in options—accelerate your experimentation.
- 🪧 **Datalab Marker API Content Extraction Support**: Seamlessly extract content from files and documents using the Datalab Marker API directly in your workflows, enabling more robust structured data extraction for RAG and document processing with just a simple engine switch in the UI.
- ⚡ **Parallelized Base Model Fetching**: Experience noticeably faster startup and model refresh times—base model data now loads in parallel, drastically shortening delays in busy or large-scale deployments.
- 🧠 **Efficient Function Loading and Caching**: Functions are now only reloaded if their content changes, preventing unnecessary duplicate loads, saving bandwidth, and boosting performance.
- 🌍 **Localization & Translation Enhancements**: Improved and expanded Simplified, Traditional Chinese, and Russian translations, providing smoother, more accurate, and context-aware experiences for global users.

### Fixed

- 💬 **Stable Message Input Box**: Fixed an issue where the message input box would shift unexpectedly (especially on mobile or with screen reader support), ensuring a smooth and reliable typing experience for every user.
- 🔊 **Reliable Read Aloud (Text-to-Speech)**: Read aloud now works seamlessly across messages, so users depending on TTS for accessibility or multitasking will experience uninterrupted and clear voice playback.
- 🖼 **Image Preview and Download Restored**: Fixed problems with image preview and downloads, ensuring frictionless creation, previewing, and downloading of images in your chats—no more interruptions in creative or documentation workflows.
- 📱 **Improved Mobile Styling for Workspace Capabilities**: Capabilities management is now readable and easy-to-use even on mobile devices, empowering admins and users to manage access quickly on the go.
- 🔁 **/api/v1/retrieval/query/collection Endpoint Reliability**: Queries to retrieval collections now return the expected results, bolstering the reliability of your knowledge workflows and citation-ready responses.

### Removed

- 🧹 **Duplicate CSS Elements**: Streamlined the UI by removing redundant CSS, reducing clutter and improving load times for a smoother visual experience.

## [0.6.11] - 2025-05-27

### Added

- 🟢 **Ollama Model Status Indicator in Model Selector**: Instantly see which Ollama models are currently loaded with a clear indicator in the model selector, helping you stay organized and optimize local model usage.
- 🗑️ **Unload Ollama Model Directly from Model Selector**: Easily release memory and resources by unloading any loaded Ollama model right in the model selector—streamline hardware management without switching pages.
- 🗣️ **User-Configurable Speech-to-Text Language Setting**: Improve transcription accuracy by letting individual users explicitly set their preferred STT language in their settings—ideal for multilingual teams and clear audio capture.
- ⚡ **Granular Audio Playback Speed Control**: Instead of just presets, you can now choose granular audio speed using a numeric input, giving you complete control over playback pace in transcriptions and media reviews.
- 📦 **GZip, Brotli, ZStd Compression Middleware**: Enjoy significantly faster page loads and reduced bandwidth usage with new server-side compression—giving users a snappier, more efficient experience.
- 🏷️ **Configurable Weight for BM25 in Hybrid Search**: Fine-tune search relevance by adjusting the weight for BM25 inside hybrid search from the UI, letting you tailor knowledge search results to your workflow.
- 🧪 **Bypass File Creation with CTRL + SHIFT + V**: When “Paste Large Text as File” is enabled, use CTRL + SHIFT + V to skip the file creation dialog and instantly upload text as a file—perfect for rapid document prep.
- 🌐 **Bypass Web Loader in Web Search**: Choose to bypass web content loading and use snippets directly in web search for faster, more reliable results when page loads are slow or blocked.
- 🚀 **Environment Variable: WEBUI_AUTH_TRUSTED_GROUPS_HEADER**: Now sync and manage user groups directly via trusted HTTP header, unlocking smoother single sign-on and identity integrations for organizations.
- 🏢 **Workspace Models Visibility Controls**: You can now hide workspace-level models from both the model selector and shared environments—keep your team focused and reduce clutter from rarely-used endpoints.
- 🛡️ **Copy Model Link**: You can now copy a direct link to any model—including those hidden from the selector—making sharing and onboarding others more seamless.
- 🔗 **Load Function Directly from URL**: Simplify custom function management—just paste any GitHub function URL into Open WebUI and import new functions in seconds.
- ⚙️ **Custom Name/Description for External Tool Servers**: Personalize and clarify external tool servers by assigning custom names and descriptions, making it easier to manage integrations in large-scale workspaces.
- 🌍 **Custom OpenAPI JSON URL Support for Tool Servers**: Supports specifying any custom OpenAPI JSON URL, unlocking more flexible integration with any backend for tool calls.
- 📊 **Source Field Now Displays in Non-Streaming Responses with Attachments**: When files or knowledge are attached, the "source" field now appears for all responses, even in non-streaming mode—enabling improved citation workflow.
- 🎛 **Pinned Chats**: Reduced payload size on pinned chat requests—leading to faster load times and less data usage, especially on busy warehouses.
- 🛠 **Import/Export Default Prompt Suggestions**: Enjoy one-click import/export of prompt suggestions, making it much easier to share, reuse, and manage best practices across teams or deployments.
- 🍰 **Banners Now Sortable from Admin Settings**: Quickly re-order or prioritize banners, letting you highlight the most critical info for your team.
- 🛠 **Advanced Chat Parameters—Clearer Ollama Support Labels**: Parameters and advanced settings now explicitly indicate if they are Ollama-specific, reducing confusion and improving setup accuracy.
- 🤏 **Scroll Bar Thumb Improved for Better Visibility**: Enhanced scrollbar styling makes navigation more accessible and visually intuitive.
- 🗄️ **Modal Redesign for Archived and User Chat Listings**: Clean, modern modal interface for browsing archived and user-specific chats makes locating conversations faster and more pleasant.
- 📝 **Add/Edit Memory Modal UX**: Memory modals are now larger and have resizable input fields, supporting easier editing of long or complex memory content.
- 🏆 **Translation & Localization Enhancements**: Major upgrades to Chinese (Simplified & Traditional), Korean, Russian, German, Danish, Finnish—not just fixing typos, but consistency, tone, and terminology for a more natural native-language experience.
- ⚡ **General Backend Stability & Security Enhancements**: Various backend refinements ensure a more resilient, reliable, and secure platform for smoother operation and peace of mind.

### Fixed

- 🖼️ **Image Generation with Allowed File Extensions Now Works Reliably**: Ensure seamless image generation even when strict file extension rules are set—no more blocked creative workflows due to technical hiccups.
- 🗂 **Remove Leading Dot for File Extension Check**: Fixed an issue where file validation failed because of a leading dot, making file uploads and knowledge management more robust.
- 🏷️ **Correct Local/External Model Classification**: The platform now accurately distinguishes between local and external models—preventing local models from showing up as external (and vice versa)—ensuring seamless setup, clarity, and management of your AI model endpoints.
- 📄 **External Document Loader Now Functions as Intended**: External document loaders are reliably invoked, ensuring smoother knowledge ingestion from external sources—expanding your RAG and knowledge workflows.
- 🎯 **Correct Handling of Toggle Filters**: Toggle filters are now robustly managed, preventing accidental auto-activation and ensuring user preferences are always respected.
- 🗃 **S3 Tagging Character Restrictions Fixed**: Tags for files in S3 now automatically meet Amazon’s allowed character set, avoiding upload errors and ensuring cross-cloud compatibility.
- 🛡️ **Authentication Now Uses Password Hash When Duplicate Emails Exist**: Ensures account security and prevents access issues if duplicate emails are present in your system.

### Changed

- 🧩 **Admin Settings: OAuth Redirects Now Use WEBUI_URL**: The OAuth redirect URL is now based on the explicitly set WEBUI_URL, ensuring single sign-on and identity provider integrations always send users to the correct frontend.

### Removed

- 💡 **Duplicate/Typo Component Removals**: Obsolete components have been cleaned up, reducing confusion and improving overall code quality for the team.
- 🚫 **Streaming Upsert in Pinecone Removed**: Removed streaming upsert references for better compatibility and future-proofing with latest Pinecone SDK updates.

## [0.6.10] - 2025-05-19

### Added

- 🧩 **Experimental Azure OpenAI Support**: Instantly connect to Azure OpenAI endpoints by simply pasting your Azure OpenAI URL into the model connections—bringing flexible, enterprise-grade AI integration directly to your workflow.
- 💧 **Watermark AI Responses**: Easily add a visible watermark to AI-generated responses for clear content provenance and compliance with EU AI regulations—perfect for regulated environments and transparent communication.
- 🔍 **Enhanced Search Experience with Dedicated Modal**: Enjoy a modern, powerful search UI in a dedicated modal (open with Ctrl/Cmd + K) accessible from anywhere—quickly find chats, models, or content and boost your productivity.
- 🔲 **"Toggle" Filter Type for Chat Input**: Add interactive toggle filters (e.g. Web Search, Image, Code Interpreter) right into the chat input—giving you one-click control to activate features and simplifying the chat configuration process.
- 🧰 **Granular Model Capabilities Editor**: Define detailed capabilities and feature access for each AI model directly from the model editor—enabling tailored behavior, modular control, and a more customized AI environment for every team or use case.
- 🌐 **Flexible Local and External Connection Support**: Now you can designate any AI connection—whether OpenAI, Ollama, or others—as local or external, enabling seamless setup for on-premise, self-hosted, or cloud configurations and giving you maximum control and flexibility.
- 🗂️ **Allowed File Extensions for RAG**: Gain full control over your Retrieval-Augmented Generation (RAG) workflows by specifying exactly which file extensions are permitted for upload, improving security and relevance of indexed documents.
- 🔊 **Enhanced Audio Transcription Logic**: Experience smoother, more reliable audio transcription—very long audio files are now automatically split and processed in segments, preventing errors and ensuring even challenging files are transcribed seamlessly, all part of a broader stability upgrade for robust media workflows.
- 🦾 **External Document Loader Support**: Enhance knowledge base building by integrating documents using external loaders from a wide range of data sources, expanding what your AI can read and process.
- 📝 **Preview Button for Code Artifacts**: Instantly jump from an HTML code block to its associated artifacts page with the click of a new preview button—speeding up review and streamlining analysis.
- 🦻 **Screen Reader Support for Response Messages**: All chat responses are now fully compatible with screen readers, making the platform more inclusive and accessible for everyone.
- 🧑💼 **Customizable Pending User Overlay**: You can now tailor the overlay title and content shown to pending users, ensuring onboarding messaging is perfectly aligned with your organization’s tone and requirements.
- 🔐 **Option to Disable LDAP Certificate Validation**: You can now disable LDAP certificate validation for maximum flexibility in diverse IT environments—making integrations and troubleshooting much easier.
- 🎯 **Workspace Search by Username or Email**: Easily search across workspace pages using any username or email address, streamlining user and resource management.
- 🎨 **High Contrast & Dark Mode Enhancements**: Further improved placeholder, input, suggestion, toast, and model selector contrasts—including a dedicated placeholder dark mode—for more comfortable viewing in all lighting conditions.
- 🛡️ **Refined Security for Pipelines & Model Uploads**: Strengthened safeguards against path traversal vulnerabilities during uploads—ensuring your platform’s document and model management remains secure.
- 🌏 **Major Localization Upgrades**: Comprehensive translation updates and improvements across Korean, Bulgarian, Catalan, Japanese, Italian, Traditional Chinese, and Spanish—including more accurate AI terminology for clarity; your experience is now more natural, inclusive, and professional everywhere.
- 🦾 **General Backend Stability & Security**: Multiple backend improvements (including file upload, command navigation, and logging refactorings) deliver increased resilience, better error handling, and a more robust platform for all users.

### Fixed

- ✅ **Evaluation Feedback Endpoint Reliability**: Addressed issues with feedback submission endpoints to ensure seamless user feedback collection on model responses.
- 🫰 **Model List State Fixes**: Resolved issues where model status toggles in the workspace page might inadvertently switch or confuse state, making the management of active/inactive models more dependable.
- ✍️ **Admin Signup Logic**: Admin self-signup experience and validation flow is smoother and more robust.
- 🔁 **Signout Redirect Flow Improved**: Logging out now redirects more reliably, reducing confusion and making session management seamless.

## [0.6.9] - 2025-05-10

### Added

- 📝 **Edit Attached Images/Files in Messages**: You can now easily edit your sent messages by removing attached files—streamlining document management, correcting mistakes on the fly, and keeping your chats clutter-free.
- 🚨 **Clear Alerts for Private Task Models**: When interacting with private task models, the UI now clearly alerts you—making it easier to understand resource availability and access, reducing confusion during workflow setup.

### Fixed

- 🛡️ **Confirm Dialog Focus Trap Reliability**: The focus now stays correctly within confirmation dialogs, ensuring keyboard navigation and accessibility is seamless and preventing accidental operations—especially helpful during critical or rapid workflows.
- 💬 **Temporary Chat Admin Controls & Session Cleanliness**: Admins are now able to properly enable temporary chat mode without errors, and previous session prompts or tool selections no longer carry over—delivering a fresh, predictable, and consistent temporary chat experience every time.
- 🤖 **External Reranker Integration Functionality Restored**: External reranker integrations now work correctly, allowing you to fully leverage advanced ranking services for sharper, more relevant search results in your RAG and knowledge base workflows.

## [0.6.8] - 2025-05-10

### Added

- 🏆 **External Reranker Support for Knowledge Base Search**: Supercharge your Retrieval-Augmented Generation (RAG) workflows with the new External Reranker integration; easily plug in advanced reranking services via the UI to deliver sharper and more relevant search results, accelerating research and insight discovery.
- 📤 **Unstylized PDF Export Option (Reduced File Size)**: When exporting chat transcripts or documents, you can now choose an unstylized PDF export for snappier downloads, minimal file size, and clean data archiving—perfect for large-scale storage or sharing.
- 📝 **Vazirmatn Font for Persian & Arabic**: Arabic and Persian users will now see their text beautifully rendered with the specialized Vazirmatn font for an improved localized reading experience.
- 🏷️ **SharePoint Tenant ID Support for OneDrive**: You can now specify a SharePoint tenant ID in OneDrive settings for seamless authentication and granular enterprise integration.
- 👤 **Refresh OAuth Profile Picture**: Your OAuth profile picture now updates in real-time, ensuring your presence and avatar always match your latest identity across integrated platforms.
- 🔧 **Milvus Configuration Improvements**: Configure index and metric types for Milvus directly within settings; take full control of your vector database for more accurate and robust AI search experiences.
- 🛡️ **S3 Tagging Toggle for Compatibility**: Optional S3 tagging via an environment toggle grants full compatibility with all storage backends—including those that don’t support tagging like Cloudflare R2—ensuring error-free attachment and document management.
- 👨🦯 **Icon Button Accessibility Improvements**: Key interactive icon-buttons now include aria-labels and ARIA descriptions, so screen readers provide precise guidance about what action each button performs for improved accessibility.
- ♿ **Enhanced Accessibility with Modal Focus Trap**: Modal dialogs and pop-ups now feature a focus trap and improved ARIA roles, ensuring seamless navigation and screen reader support—making the interface friendlier for everyone, including keyboard and assistive tech users.
- 🏃 **Improved Admin User List Loading Indicator**: The user list loading experience is now clearer and more responsive in the admin panel.
- 🧑🤝🧑 **Larger Admin User List Page Size**: Admins can now manage up to 30 users per page in the admin interface, drastically reducing pagination and making large user teams easier and faster to manage.
- 🌠 **Default Code Interpreter Prompt Clarified**: The built-in code interpreter prompt is now more explicit, preventing AI from wrapping code in Markdown blocks when not needed—ensuring properly formatted code runs as intended every time.
- 🧾 **Improved Default Title Generation Prompt Template**: Title generation now uses a robust template for reliable JSON output, improving chat organization and searchability.
- 🔗 **Support Jupyter Notebooks with Non-Root Base URLs**: Notebook-based code execution now supports non-root deployed Jupyter servers, granting full flexibility for hybrid or multi-user setups.
- 📰 **UI Scrollbar Always Visible for Overflow Tools**: When available tools overflow the display, the scrollbar is now always visible and there’s a handy "show all" toggle, making navigation of large toolsets snappier and more intuitive.
- 🛠️ **General Backend Refactoring for Stability**: Multiple under-the-hood improvements have been made across backend components, ensuring smoother performance, fewer errors, and a more reliable overall experience for all users.
- 🚀 **Optimized Web Search for Faster Results**: Web search speed and performance have been significantly enhanced, delivering answers and sources in record time to accelerate your research-heavy workflows.
- 💡 **More Supported Languages**: Expanded language support ensures an even wider range of users can enjoy an intuitive and natural interface in their native tongue.

### Fixed

- 🏃♂️ **Exhausting Workers in Nginx Reverse Proxy Due to Websocket Fix**: Websocket sessions are now fully compatible behind Nginx, eliminating worker exhaustion and restoring 24/7 reliability for real-time chats even in complex deployments.
- 🎤 **Audio Transcription Issue with OpenAI Resolved**: OpenAI-based audio transcription now handles WebM and newer formats without error, ensuring seamless voice-to-text workflows every time.
- 👉 **Message Input RTL Issue Fixed**: The chat message input now displays correctly for right-to-left languages, creating a flawless typing and reading experience for Arabic, Hebrew, and more.
- 🀄 **Katex: Proper Rendering of Chinese Characters Next to Math**: Math formulas now render perfectly even when directly adjacent to Chinese (CJK) characters, improving visual clarity for multilingual teams and cross-language documents.
- 🔂 **Duplicate Web Search URLs Eliminated**: Search results now reliably filter out URL duplicates, so your knowledge and search citations are always clean, trimmed, and easy to review.
- 📄 **Markdown Rendering Fixed in Knowledge Bases**: Markdown is now displayed correctly within knowledge bases, enabling better formatting and clarity of information-rich files.
- 🗂️ **LDAP Import/Loading Issue Resolved**: LDAP user imports process correctly, ensuring smooth onboarding and access without interruption.
- 🌎 **Pinecone Batch Operations and Async Safety**: All Pinecone operations (batch insert, upsert, delete) now run efficiently and safely in an async environment, boosting performance and preventing slowdowns in large-scale RAG jobs.

## [0.6.7] - 2025-05-07

### Added

- 🌐 **Custom Azure TTS API URL Support Added**: You can now define a custom Azure Text-to-Speech endpoint—enabling flexibility for enterprise deployments and regional compliance.
- ⚙️ **TOOL_SERVER_CONNECTIONS Environment Variable Support**: Easily configure and deploy tool servers via environment variables, streamlining setup and enabling faster enterprise provisioning.
- 👥 **Enhanced OAuth Group Handling as String or List**: OAuth group data can now be passed as either a list or a comma-separated string, improving compatibility with varied identity provider formats and reducing onboarding friction.

### Fixed

- 🧠 **Embedding with Ollama Proxy Endpoints Restored**: Fixed an issue where missing API config broke embedding for proxied Ollama models—ensuring consistent performance and compatibility.
- 🔐 **OIDC OAuth Login Issue Resolved**: Users can once again sign in seamlessly using OpenID Connect-based OAuth, eliminating login interruptions and improving reliability.
- 📝 **Notes Feature Access Fixed for Non-Admins**: Fixed an issue preventing non-admin users from accessing the Notes feature, restoring full cross-role collaboration capabilities.
- 🖼️ **Tika Loader Image Extraction Problem Resolved**: Ensured TikaLoader now processes 'extract_images' parameter correctly, restoring complete file extraction functionality in document workflows.
- 🎨 **Automatic1111 Image Model Setting Applied Properly**: Fixed an issue where switching to a specific image model via the UI wasn’t reflected in generation, re-enabling full visual creativity control.
- 🏷️ **Multiple XML Tags in Messages Now Parsed Correctly**: Fixed parsing issues when messages included multiple XML-style tags, ensuring clean and unbroken rendering of rich content in chats.
- 🖌️ **OpenAI Image Generation Issues Resolved**: Resolved broken image output when using OpenAI’s image generation, ensuring fully functional visual creation workflows.
- 🔎 **Tool Server Settings UI Privacy Restored**: Prevented restricted users from accessing tool server settings via search—restoring tight permissions control and safeguarding sensitive configurations.
- 🎧 **WebM Audio Transcription Now Supported**: Fixed an issue where WebM files failed during audio transcription—these formats are now fully supported, ensuring smoother voice note workflows and broader file compatibility.

## [0.6.6] - 2025-05-05
|
||||
|
||||
### Added
|
||||
|
||||
- 📝 **AI-Enhanced Notes (With Audio Transcription)**: Effortlessly create notes, attach meeting or voice audio, and let the AI instantly enhance, summarize, or refine your notes using audio transcriptions—making your documentation smarter, cleaner, and more insightful with minimal effort.
|
||||
- 🔊 **Meeting Audio Recording & Import**: Seamlessly record audio from your meetings or capture screen audio and attach it to your notes—making it easier to revisit, annotate, and extract insights from important discussions.
|
||||
- 📁 **Import Markdown Notes Effortlessly**: Bring your existing knowledge library into Open WebUI by importing your Markdown notes, so you can leverage all advanced note management and AI features right away.
|
||||
- 👥 **Notes Permissions by User Group**: Fine-tune access and editing rights for notes based on user roles or groups, so you can delegate writing or restrict sensitive information as needed.
|
||||
- ☁️ **OneDrive & SharePoint Integration**: Keep your content in sync by connecting notes and files directly with OneDrive or SharePoint—unlocking fast enterprise import/export and seamless collaboration with your existing workflows.
|
||||
- 🗂️ **Paginated User List in Admin Panel**: Effortlessly manage and search through large teams via the new paginated user list—saving time and streamlining user administration in big organizations.
|
||||
- 🕹️ **Granular Chat Share & Export Permissions**: Enjoy enhanced control over who can share or export chats, enabling tighter governance and privacy in team and enterprise settings.
|
||||
- 🛑 **User Role Change Confirmation Dialog**: Reduce accidental privilege changes with a required confirmation step before updating user roles—improving security and preventing costly mistakes in team management.
|
||||
- 🚨 **Audit Log for Failed Login Attempts**: Quickly detect unauthorized access attempts or troubleshoot user login problems with detailed logs of failed authentication right in the audit trail.
|
||||
- 💡 **Dedicated 'Generate Title' Button for Chats**: Swiftly organize every conversation—tap the new button to let AI create relevant, clear titles for all your chats, saving time and reducing clutter.
|
||||
- 💬 **Notification Sound Always-On Option**: Take control of your notifications by setting sound alerts to always play—helping you stay on top of important updates in busy environments.
|
||||
- 🆔 **S3 File Tagging Support**: Uploaded files to S3 now include tags for better organization, searching, and integration with your file management policies.
|
||||
- 🛡️ **OAuth Blocked Groups Support**: Gain more control over group-based access by explicitly blocking specified OAuth groups—ideal for complex identity or security requirements.
|
||||
- 🚀 **Optimized Faster Web Search & Multi-Threaded Queries**: Enjoy dramatically faster web search and RAG (retrieval augmented generation) with revamped multi-threaded search—get richer, more accurate results in less time.
|
||||
- 🔍 **All-Knowledge Parallel Search**: Searches across your entire knowledge base now happen in parallel even in non-hybrid mode, speeding up responses and improving knowledge accuracy for every question.
|
||||
- 🌐 **New Firecrawl & Yacy Web Search Integrations**: Expand your world of information with two new advanced search engines—Firecrawl for deeper web insight and Yacy for decentralized, privacy-friendly search capabilities.
|
||||
- 🧠 **Configurable Docling OCR Engine & Language**: Use environment variables to fine-tune Docling OCR engine and supported languages for smarter, more tailored document extraction and RAG workflows.
|
||||
- 🗝️ **Enhanced Sentence Transformers Configuration**: Added new environment variables for easier set up and advanced customization of Sentence Transformers—ensuring best fit for your embedding needs.
|
||||
- 🌲 **Pinecone Vector Database Integration**: Index, search, and manage knowledge at enterprise scale with full native support for Pinecone as your vector database—effortlessly handle even the biggest document sets.
|
||||
- 🔄 **Automatic Requirements Installation for Tools & Functions**: Never worry about lost dependencies on restart—external function and tool requirements are now auto-installed at boot, ensuring tools always “just work.”
|
||||
- 🔒 **Automatic Sign-Out on Token Expiry**: Security is smarter—users are now automatically logged out if their authentication token expires, protecting sensitive content and ensuring compliance without disruption.
|
||||
- 🎬 **Automatic YouTube Embed Detection**: Paste YouTube links and see instant in-chat video embeds—no more manual embedding, making knowledge sharing and media consumption even easier for every team.
|
||||
- 🔄 **Expanded Language & Locale Support**: Translations for Danish, French, Russian, Traditional Chinese, Simplified Chinese, Thai, Catalan, German, and Korean have been upgraded, offering smoother, more natural user experiences across the platform.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🔒 **Tighter HTML Token Security**: HTML rendering is now restricted to admin-uploaded tokens only, reducing any risk of XSS and keeping your data safe.
|
||||
- 🔐 **Refined HTML Security and Token Handling**: Further hardened how HTML tokens and content are handled, guaranteeing even stronger resistance to security vulnerabilities and attacks.
|
||||
- 🔏 **Correct Model Usage with Ollama Proxy Prefixes**: Enhanced model reference handling so proxied models in Ollama always download and run correctly—even when using custom prefixes.
|
||||
- 📥 **Video File Upload Handling**: Prevented video files from being misclassified as text, fixing bugs with uploads and ensuring media files work as expected.
|
||||
- 🔄 **No More Dependent WebSocket Sequential Delays**: Streamlined WebSocket operation to prevent delays and maintain snappy real-time collaboration, especially in multi-user environments.
|
||||
- 🛠️ **More Robust Action Module Execution**: Multiple actions in a module now trigger as designed, increasing automation and scripting flexibility.
|
||||
- 📧 **Notification Webhooks**: Ensured that notification webhooks are always sent for user events, even when the user isn’t currently active.
|
||||
- 🗂️ **Smarter Knowledge Base Reindexing**: Knowledge reindexing continues even when corrupt or missing collections are encountered, keeping your search features running reliably.
|
||||
- 🏷️ **User Import with Profile Images**: When importing users, their profile images now come along—making onboarding and collaboration visually clearer from day one.
|
||||
- 💬 **OpenAI o-Series Universal Support**: All OpenAI o-series models are now seamlessly recognized and supported, unlocking more advanced capabilities and model choices for every workflow.
|
||||
|
||||
### Changed
|
||||
|
||||
- 📜 **Custom License Update & Contributor Agreement**: Open WebUI now operates under a custom license, with a Contributor License Agreement required by default—see https://docs.openwebui.com/license/ for details. This change supports sustainable open innovation for the community.
|
||||
- 🔨 **CUDA Docker Images Updated to 12.8**: Upgraded CUDA image support for faster, more compatible model inference and future-proof GPU performance in your AI infrastructure.
|
||||
- 🧱 **General Backend Refactoring for Reliability**: Continuous stability improvements streamline backend logic, reduce errors, and lay a stronger foundation for the next wave of feature releases—all under the hood for a more dependable WebUI.
|
||||
|
||||
## [0.6.5] - 2025-04-14
|
||||
|
||||
### Added
|
||||
|
||||
- 🛂 **Granular Voice Feature Permissions Per User Group**: Admins can now separately manage access to Speech-to-Text (record voice), Text-to-Speech (read aloud), and Tool Calls for each user group—giving teams tighter control over voice features and enhanced governance across roles.
|
||||
- 🗣️ **Toggle Voice Activity Detection (VAD) for Whisper STT**: A new environment variable lets you enable or disable VAD filtering with the built-in Whisper speech-to-text, giving you flexibility to optimize for different audio quality and response accuracy levels (see the sketch after this list).
|
||||
- 📋 **Copy Formatted Response Mode**: You can now enable “Copy Formatted” in Settings > Interface to copy AI responses exactly as styled (with rich formatting, links, and structure preserved), making it faster and cleaner to paste into documents, emails, or reports.
|
||||
- ⚙️ **Backend Stability and Performance Enhancements**: General backend refactoring improves system resilience, consistency, and overall reliability—offering smoother performance across workflows whether chatting, generating media, or using external tools.
|
||||
- 🌎 **Translation Refinements Across Multiple Languages**: Updated translations deliver smoother language localization, clearer labels, and improved international usability throughout the UI—ensuring a better experience for non-English speakers.
|
||||
|
||||
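A sketch of toggling VAD on a pip-based install; `WHISPER_VAD_FILTER` is an assumed variable name—check the documentation for the exact spelling.

```bash
# Hypothetical sketch: turn off VAD filtering for built-in Whisper STT.
# WHISPER_VAD_FILTER is an assumed variable name--check the docs.
export WHISPER_VAD_FILTER=false
open-webui serve
```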
### Fixed
|
||||
|
||||
- 🛠️ **LDAP Login Reliability Restored**: Resolved a critical issue where some LDAP setups failed due to attribute parsing—ensuring consistent, secure, and seamless user authentication across enterprise deployments.
|
||||
- 🖼️ **Image Generation in Temporary Chats Now Works Properly**: Fixed a bug where image outputs weren’t generated during temporary chats—visual content can now be used reliably in all chat modes without interruptions.
|
||||
|
||||
## [0.6.4] - 2025-04-12
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🛠️ **RAG_TEMPLATE Display Issue Resolved**: Fixed a formatting problem where the custom RAG_TEMPLATE wasn't correctly rendered in the interface—ensuring that custom retrieval prompts now appear exactly as intended for more reliable prompt engineering.
|
||||
|
||||
## [0.6.3] - 2025-04-12
|
||||
|
||||
### Added
|
||||
|
||||
- 🧪 **Auto-Artifact Detection Toggle**: Automatically detects artifacts in results—but now you can disable this behavior under advanced settings for full control.
|
||||
- 🖼️ **Widescreen Mode for Shared Chats**: Shared link conversations now support widescreen layouts—perfect for presentations or easier review across wider displays.
|
||||
- 🔁 **Reindex Knowledge Files on Demand**: Admins can now trigger reindexing of all knowledge files after changing embeddings—ensuring immediate alignment with new models for optimal RAG performance.
|
||||
- 📄 **OpenAPI YAML Format Support**: External tools can now use YAML-format OpenAPI specs—making integration simpler for developers familiar with YAML-based configurations.
|
||||
- 💬 **Message Content Copy Behavior**: Copy action now excludes 'details' tags—streamlining clipboard content when sharing or pasting summaries elsewhere.
|
||||
- 🧭 **Sougou Web Search Integration**: New search engine option added—enhancing global relevance and diversity of search sources for multilingual users.
|
||||
- 🧰 **Frontend Web Loader Engine Configuration**: Admins can now set preferred web loader engine for RAG workflows directly from the frontend—offering more control across setups.
|
||||
- 👥 **Multi-Model Chat Permission Control**: Admins can manage access to multi-model chats per user group—allowing tighter governance in team environments.
|
||||
- 🧱 **Persistent Configuration Can Be Disabled**: New environment variable lets advanced users and hosts turn off persistent configs—ideal for volatile or stateless deployments.
|
||||
- 🧠 **Elixir Code Highlighting Support**: Elixir syntax is now beautifully rendered in code blocks—perfect for developers using this language in AI or automation projects.
|
||||
- 🌐 **PWA External Manifest URL Support**: You can now define an external manifest.json—integrate Open WebUI seamlessly in managed or proxy-based PWA environments like Cloudflare Zero Trust (see the sketch after this list).
|
||||
- 🧪 **Azure AI Speech-to-Text Provider Integration**: Easily transcribe large audio files (up to 200MB) with high accuracy using Microsoft's Azure STT—fully configurable in Audio Settings.
|
||||
- 🔏 **PKCE (Code Challenge Method) Support for OIDC**: Enhance your OIDC login security with Proof Key for Code Exchange—ideal for zero-trust and native client apps.
|
||||
- ✨ **General UI/UX Enhancements**: Numerous refinements across layout, styling, and tool interactions—reducing visual noise and improving overall usability across key workflows.
|
||||
- 🌍 **Translation Updates Across Multiple Languages**: Refined Catalan, Russian, Chinese (Simplified & Traditional), Hungarian, and Spanish translations for clearer navigation and instructions globally.
|
||||
|
||||
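A sketch of the external-manifest option above. The `EXTERNAL_PWA_MANIFEST_URL` variable is read in `env.py` later in this commit; the manifest URL itself is illustrative.

```bash
# Serve the PWA manifest from an external, proxy-friendly location.
# EXTERNAL_PWA_MANIFEST_URL is read in env.py; the URL below is illustrative.
docker run -d -p 3000:8080 \
  -e EXTERNAL_PWA_MANIFEST_URL="https://cdn.example.com/openwebui/manifest.json" \
  -v open-webui:/app/backend/data \
  --name open-webui \
  ghcr.io/open-webui/open-webui:main
```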
### Fixed
|
||||
|
||||
- 💥 **Chat Completion Error with Missing Models Resolved**: Fixed internal server error when referencing a model that doesn’t exist—ensuring graceful fallback and clear error guidance.
|
||||
- 🔧 **Correct Knowledge Base Citations Restored**: Citations generated by RAG workflows now show accurate references—ensuring verifiability in outputs from sourced content.
|
||||
- 🎙️ **Broken OGG/WebM Audio Upload Handling for OpenAI Fixed**: Uploading OGG or WebM files now converts properly to WAV before transcription—restoring accurate AI speech recognition workflows.
|
||||
- 🔐 **Tool Server 'Session' Authentication Restored**: Previously broken session auth on external tool servers is now fully functional—ensuring secure and seamless access to connected tools.
|
||||
- 🌐 **Folder-Based Chat Rename Now Updates Correctly**: Renaming chats in folders now reflects instantly everywhere—improving chat organization and clarity.
|
||||
- 📜 **KaTeX Overflow Displays Fixed**: Math expressions now stay neatly within message bounds—preserving layout consistency even with long formulas.
|
||||
- 🚫 **Stopping Ongoing Chat Fixed**: You can now return to an active (ongoing) chat and stop generation at any time—ensuring full control over sessions.
|
||||
- 🔧 **TOOL_SERVERS / TOOL_SERVER_CONNECTIONS Indexing Issue Fixed**: Fixed a mismatch between tool lists and their access paths—restoring full function and preventing confusion in tool management.
|
||||
- 🔐 **LDAP Login Handles Multiple Emails**: When LDAP returns multiple email attributes, the first valid one is now used—ensuring login success and account consistency.
|
||||
- 🧩 **Model Visibility Toggle Fix**: Toggling model visibility now works even for untouched models—letting admins smoothly manage user access across base models.
|
||||
- ⚙️ **Cross-Origin manifest.json Now Loads Properly**: Compatibility issues with Cloudflare Zero Trust (and others) resolved, allowing manifest.json to load behind authenticated proxies.
|
||||
|
||||
### Changed
|
||||
|
||||
- 🔒 **Default Access Scopes Set to Private for All Resources**: Models, tools, and knowledge are now private by default when created—ensuring better baseline security and visibility controls.
|
||||
- 🧱 **General Backend Refactoring for Stability**: Numerous invisible improvements enhance backend scalability, security, and maintainability—powering upcoming features with a stronger foundation.
|
||||
- 🧩 **Stable Dependency Upgrades**: Updated key platform libraries—Chromadb (0.6.3), pgvector (0.4.0), Azure Identity (1.21.0), and Youtube Transcript API (1.0.3)—for improved compatibility, functionality, and security.
|
||||
|
||||
## [0.6.2] - 2025-04-06
|
||||
|
||||
### Added
|
||||
|
||||
- 🌍 **Improved Global Language Support**: Expanded and refined translations across multiple languages to enhance clarity and consistency for international users.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🛠️ **Accurate Tool Descriptions from OpenAPI Servers**: External tools now use full endpoint descriptions instead of summaries when generating tool specifications—helping AI models understand tool purpose more precisely and choose the right tool more accurately in tool workflows.
|
||||
- 🔧 **Precise Web Results Source Attribution**: Fixed a key issue where all web search results showed the same source ID—now each result gets its correct and distinct source, ensuring accurate citations and traceability.
|
||||
- 🔍 **Clean Web Search Retrieval**: Web search now retains only results from URLs where real content was successfully fetched—improving accuracy and removing empty or broken links from citations.
|
||||
- 🎵 **Audio File Upload Response Restored**: Resolved an issue where uploading audio files did not return valid responses, restoring smooth file handling for transcription and audio-based workflows.
|
||||
|
||||
### Changed
|
||||
|
||||
- 🧰 **General Backend Refactoring**: Multiple behind-the-scenes improvements streamline backend performance, reduce complexity, and ensure a more stable, maintainable system overall—making everything smoother without changing your workflow.
|
||||
|
||||
## [0.6.1] - 2025-04-05
|
||||
|
||||
### Added
|
||||
|
||||
- 🛠️ **Global Tool Servers Configuration**: Admins can now centrally configure global external tool servers from Admin Settings > Tools, allowing seamless sharing of tool integrations across all users without manual setup per user.
|
||||
- 🔐 **Direct Tool Usage Permission for Users**: Introduced a new user-level permission toggle that grants non-admin users access to direct external tools, empowering broader team collaboration while maintaining control.
|
||||
- 🧠 **Mistral OCR Content Extraction Support**: Added native support for Mistral OCR as a high-accuracy document loader, drastically improving text extraction from scanned documents in RAG workflows.
|
||||
- 🖼️ **Tools Indicator UI Redesign**: Enhanced message input now smartly displays both built-in and external tools via a unified dropdown, making it simpler and more intuitive to activate tools during conversations.
|
||||
- 📄 **RAG Prompt Improved and More Coherent**: Default RAG system prompt has been revised to be more clear and citation-focused—admins can leave the template field empty to use this new gold-standard prompt.
|
||||
- 🧰 **Performance & Developer Improvements**: Major internal restructuring of several tool-related components, simplifying styling and merging external/internal handling logic, resulting in better maintainability and performance.
|
||||
- 🌍 **Improved Translations**: Updated translations for Tibetan, Polish, Chinese (Simplified & Traditional), Arabic, Russian, Ukrainian, Dutch, Finnish, and French to improve clarity and consistency across the interface.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🔑 **External Tool Server API Key Bug Resolved**: Fixed a critical issue where authentication headers were not being sent when calling tools from external OpenAPI tool servers, ensuring full security and smooth tool operations.
|
||||
- 🚫 **Conditional Export Button Visibility**: UI now gracefully hides export buttons when there's nothing to export in models, prompts, tools, or functions, improving visual clarity and reducing confusion.
|
||||
- 🧪 **Hybrid Search Failure Recovery**: Resolved edge case in parallel hybrid search where empty or unindexed collections caused backend crashes—these are now cleanly skipped to ensure system stability.
|
||||
- 📂 **Admin Folder Deletion Fix**: Addressed an issue where folders created in the admin workspace couldn't be deleted, restoring full organizational flexibility for admins.
|
||||
- 🔐 **Improved Generic Error Feedback on Login**: Authentication errors now show simplified, non-revealing messages for privacy and improved UX, especially with federated logins.
|
||||
- 📝 **Tool Message with Images Improved**: Enhanced how tool-generated messages with image outputs are shown in chat, making them more readable and consistent with the overall UI design.
|
||||
- ⚙️ **Auto-Exclusion for Broken RAG Collections**: Auto-skips document collections that fail to fetch data or return "None", preventing silent errors and streamlining retrieval workflows.
|
||||
- 📝 **Docling Text File Handling Fix**: Fixed file parsing inconsistency that broke docling-based RAG functionality for certain plain text files, ensuring wider file compatibility.
|
||||
|
||||
## [0.6.0] - 2025-03-31
|
||||
|
||||
### Added
|
||||
|
||||
- 🧩 **External Tool Server Support via OpenAPI**: Connect Open WebUI to any OpenAPI-compatible REST server instantly—offering immediate integration with thousands of developer tools, SDKs, and SaaS systems for powerful extensibility. Learn more: https://github.com/open-webui/openapi-servers
|
||||
- 🛠️ **MCP Server Support via MCPO**: You can now convert and expose your internal MCP tools as interoperable OpenAPI HTTP servers within Open WebUI for seamless, plug-n-play AI toolchain creation. Learn more: https://github.com/open-webui/mcpo
|
||||
- 📨 **/messages Chat API Endpoint Support**: For power users building external AI systems, new endpoints allow precise control of messages asynchronously—feed long-running external responses into Open WebUI chats without coupling with the frontend.
|
||||
- 📝 **Client-Side PDF Generation**: PDF exports are now generated fully client-side for drastically improved output quality—perfect for saving conversations or documents.
|
||||
- 💼 **Enforced Temporary Chats Mode**: Admins can now enforce temporary chat sessions by default to align with stringent data retention and compliance requirements.
|
||||
- 🌍 **Public Resource Sharing Permission Controls**: Fine-grained user group permissions now allow enabling/disabling public sharing for models, knowledge, prompts, and tools—ideal for privacy, team control, and internal deployments.
|
||||
- 📦 **Custom pip Options for Tools/Functions**: You can now specify custom pip installation options with the "PIP_OPTIONS" and "PIP_PACKAGE_INDEX_OPTIONS" environment variables—improving compatibility, supporting private indexes, and giving you better control over Python environments (see the sketch after this list).
|
||||
- 🔢 **Editable Message Counter**: You can now double-click the message count number and jump straight to editing the index—quickly navigate complex chats or regenerate specific messages precisely.
|
||||
- 🧠 **Embedding Prefix Support Added**: Add custom prefixes to your embeddings for instruct-style tokens, enabling stronger model alignment and more consistent RAG performance.
|
||||
- 🙈 **Ability to Hide Base Models**: Optionally hide base models from the UI, helping users streamline model visibility and limit access to only usable endpoints.
|
||||
- 📚 **Docling Content Extraction Support**: Open WebUI now supports Docling as a content extraction engine, enabling smarter and more accurate parsing of complex file formats—ideal for advanced document understanding and Retrieval-Augmented Generation (RAG) workflows.
|
||||
- 🗃️ **Redis Sentinel Support Added**: Enhance deployment redundancy with support for Redis Sentinel for highly available, failover-safe Redis-based caching or pub/sub.
|
||||
- 📚 **JSON Schema Format for Ollama**: Added support for defining the format using JSON schema in Ollama-compatible models, improving flexibility and validation of model outputs.
|
||||
- 🔍 **Chat Sidebar Search "Clear" Button**: Quickly clear search filters in the chat sidebar using the new ✖️ button—streamline your chat navigation with one click.
|
||||
- 🗂️ **Auto-Focus + Enter Submit for Folder Name**: When creating a new folder, the system automatically enters rename mode with name preselected—simplifying your org workflow.
|
||||
- 🧱 **Markdown Alerts Rendering**: Blockquotes with syntax hinting (e.g. ⚠️, ℹ️, ✅) now render styled Markdown alert banners, making messages and documentation more visually structured.
|
||||
- 🔁 **Hybrid Search Runs in Parallel Now**: Hybrid (BM25 + embedding) search components now run in parallel—dramatically reducing response times and speeding up document retrieval.
|
||||
- 📋 **Cleaner UI for Tool Call Display**: Optimized the visual layout of called tools inside chat messages for better clarity and reduced visual clutter.
|
||||
- 🧪 **Playwright Timeout Now Configurable**: Default timeout for Playwright processes is now shorter and adjustable via environment variables—making web scraping more robust and tunable to your environment.
|
||||
- 📈 **OpenTelemetry Support for Observability**: Open WebUI now integrates with OpenTelemetry, allowing you to connect with tools like Grafana, Jaeger, or Prometheus for detailed performance insights and real-time visibility—entirely opt-in and fully self-hosted. Even if enabled, no data is ever sent to us, ensuring your privacy and ownership over all telemetry data. A minimal opt-in sketch follows this list.
|
||||
- 🛠 **General UI Enhancements & UX Polish**: Numerous refinements across sidebar, code blocks, modal interactions, button alignment, scrollbar visibility, and folder behavior improve overall fluidity and usability of the interface.
|
||||
- 🧱 **General Backend Refactoring**: Numerous backend components have been refactored to improve stability, maintainability, and performance—ensuring a more consistent and reliable system across all features.
|
||||
- 🌍 **Internationalization Language Support Updates**: Added Estonian and Galician languages, improved Spanish (fully revised), Traditional Chinese, Simplified Chinese, Turkish, Catalan, Ukrainian, and German for a more localized and inclusive interface.
|
||||
|
||||
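A sketch of the pip-options feature above. Both variable names come straight from this release and are whitespace-split into pip argument lists in `env.py` (shown later in this commit); the concrete option values are illustrative.

```bash
# Extra pip arguments applied when tool/function requirements are installed.
# Values are whitespace-split into argument lists (see env.py); these are examples.
export PIP_OPTIONS="--no-cache-dir --quiet"
export PIP_PACKAGE_INDEX_OPTIONS="--index-url https://pypi.example.com/simple"
open-webui serve
```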
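And a sketch of opting in to OpenTelemetry. The three variables below appear verbatim in `env.py` later in this commit, and the endpoint and service-name values match its defaults; the collector hostname is illustrative and assumes an OTLP gRPC listener on port 4317.

```bash
# Opt in to OpenTelemetry; export traces to a self-hosted OTLP collector.
# ENABLE_OTEL, OTEL_EXPORTER_OTLP_ENDPOINT, and OTEL_SERVICE_NAME appear in env.py.
docker run -d -p 3000:8080 \
  -e ENABLE_OTEL="true" \
  -e OTEL_EXPORTER_OTLP_ENDPOINT="http://otel-collector:4317" \
  -e OTEL_SERVICE_NAME="open-webui" \
  -v open-webui:/app/backend/data \
  --name open-webui \
  ghcr.io/open-webui/open-webui:main
```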
### Fixed
|
||||
|
||||
- 🧑‍💻 **Firefox Input Height Bug**: Text input in Firefox now maintains proper height, ensuring message boxes look consistent and behave predictably.
|
||||
- 🧾 **Tika Blank Line Bug**: PDFs processed with Apache Tika 3.1.0.0 no longer introduce excessive blank lines—improving RAG output quality and visual cleanliness.
|
||||
- 🧪 **CSV Loader Encoding Issues**: CSV files with unknown encodings now automatically detect character sets, resolving import errors in non-UTF-8 datasets.
|
||||
- ✅ **LDAP Auth Config Fix**: Path to certificate file is now optional for LDAP setups, fixing authentication trouble for users without preconfigured cert paths.
|
||||
- 📥 **File Deletion in Bypass Mode**: Resolved issue where files couldn’t be deleted from knowledge when “bypass embedding” mode was enabled.
|
||||
- 🧩 **Hybrid Search Result Sorting & Deduplication Fixed**: Fixed citation and sorting issues in RAG hybrid and reranker modes, ensuring retrieved documents are shown in correct order per score.
|
||||
- 🧷 **Model Export/Import Broken for a Single Model**: Fixed bug where individual models couldn’t be exported or re-imported, restoring full portability.
|
||||
- 📫 **Auth Redirect Fix**: Logged-in users are now routed properly without unnecessary login prompts when already authenticated.
|
||||
|
||||
### Changed
|
||||
|
||||
- 🧠 **Prompt Autocompletion Disabled By Default**: Autocomplete suggestions while typing are now disabled unless explicitly re-enabled in user preferences—reduces distractions while composing prompts for advanced users.
|
||||
- 🧾 **Normalize Citation Numbering**: Source citations now properly begin from "1" instead of "0"—improving consistency and professional presentation in AI outputs.
|
||||
- 📚 **Improved Error Handling from Pipelines**: Pipelines now show the actual returned error message from failed tasks rather than generic "Connection closed"—making debugging far more user-friendly.
|
||||
|
||||
### Removed
|
||||
|
||||
- 🧾 **ENABLE_AUDIT_LOGS Setting Removed**: Deprecated setting “ENABLE_AUDIT_LOGS” has been fully removed—now controlled via “AUDIT_LOG_LEVEL” instead.
|
||||
|
||||
## [0.5.20] - 2025-03-05
|
||||
|
||||
### Added
|
||||
|
||||
- **⚡ Toggle Code Execution On/Off**: You can now enable or disable code execution, providing more control over security, ensuring a safer and more customizable experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **📜 Pinyin Keyboard Enter Key Now Works Properly**: Resolved an issue where the Enter key for Pinyin keyboards was not functioning as expected, ensuring seamless input for Chinese users.
|
||||
- **🖼️ Web Manifest Loading Issue Fixed**: Addressed inconsistencies with 'site.webmanifest', guaranteeing proper loading and representation of the app across different browsers and devices.
|
||||
- **📦 Non-Root Container Issue Resolved**: Fixed a critical issue where the UI failed to load correctly in non-root containers, ensuring reliable deployment in various environments.
|
||||
|
||||
## [0.5.19] - 2025-03-04
|
||||
|
||||
### Added
|
||||
|
||||
- **📊 Logit Bias Parameter Support**: Fine-tune conversation dynamics by adjusting the Logit Bias parameter directly in chat settings, giving you more control over model responses.
|
||||
- **⌨️ Customizable Enter Behavior**: You can now configure Enter to send messages only when combined with Ctrl (Ctrl+Enter) via Settings > Interface, preventing accidental message sends.
|
||||
- **📝 Collapsible Code Blocks**: Easily collapse long code blocks to declutter your chat, making it easier to focus on important details.
|
||||
- **🏷️ Tag Selector in Model Selector**: Quickly find and categorize models with the new tag filtering system in the Model Selector, streamlining model discovery.
|
||||
- **📈 Experimental Elasticsearch Vector DB Support**: Now supports Elasticsearch as a vector database, offering more flexibility for data retrieval in Retrieval-Augmented Generation (RAG) workflows.
|
||||
- **⚙️ General Reliability Enhancements**: Various stability improvements across the WebUI, ensuring a smoother, more consistent experience.
|
||||
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔄 "Stream" Hook Activation**: Fixed an issue where the "Stream" hook only worked when globally enabled, ensuring reliable real-time filtering.
|
||||
- **📧 LDAP Email Case Sensitivity**: Resolved an issue where LDAP login failed due to email case sensitivity mismatches, improving authentication reliability.
|
||||
- **💬 WebSocket Chat Event Registration**: Fixed a bug preventing chat event listeners from being registered upon sign-in, ensuring real-time updates work properly.
|
||||
|
||||
## [0.5.18] - 2025-02-27
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity.
|
||||
- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility.
|
||||
- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing.
|
||||
|
||||
## [0.5.17] - 2025-02-27
|
||||
|
||||
### Added
|
||||
|
||||
- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking.
|
||||
- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering (see the sketch after this list). Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter).
|
||||
- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options.
|
||||
- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient.
|
||||
- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience.
|
||||
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
|
||||
|
||||
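A minimal sketch of what a filter using the new "stream" hook might look like, assuming the filter-plugin class pattern from the linked docs and an OpenAI-style `choices`/`delta` chunk payload—verify the actual event shape in the documentation.

```python
class Filter:
    # Hypothetical sketch of the "stream" hook: called once per streamed chunk.
    # The OpenAI-style choices/delta payload shape is an assumption--verify in the docs.
    def stream(self, event: dict) -> dict:
        for choice in event.get("choices", []):
            delta = choice.get("delta", {})
            if isinstance(delta.get("content"), str):
                # Redact a placeholder term from each text delta before it reaches the client.
                delta["content"] = delta["content"].replace("secret", "[REDACTED]")
        return event
```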
### Fixed
|
||||
|
||||
- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors.
|
||||
- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data.
|
||||
- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant.
|
||||
|
||||
### Changed
|
||||
|
||||
- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI.
|
||||
|
||||
## [0.5.16] - 2025-02-20
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again.
|
||||
|
||||
## [0.5.15] - 2025-02-20
|
||||
|
||||
### Added
|
||||
|
||||
- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding.
|
||||
- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval.
|
||||
- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data.
|
||||
- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI.
|
||||
- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions.
|
||||
- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs.
|
||||
- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations.
|
||||
- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI.
|
||||
- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue.
|
||||
- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow.
|
||||
|
||||
## [0.5.14] - 2025-02-17
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability.
|
||||
|
||||
## [0.5.13] - 2025-02-17
|
||||
|
||||
### Added
|
||||
|
||||
- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results.
|
||||
- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval.
|
||||
- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts.
|
||||
- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat.
|
||||
- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history.
|
||||
- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data.
|
||||
- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow.
|
||||
|
||||
### Changed
|
||||
|
||||
- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; double curly brackets '{{}}' are now required for consistency and clarity (see the example after this list).
|
||||
- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages.
|
||||
|
||||
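For example, a prompt template that previously used the placeholder `[topic]` must now be written as `{{topic}}`—square-bracket placeholders are no longer substituted (the variable name here is illustrative).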
## [0.5.12] - 2025-02-13
|
||||
|
||||
### Added
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
## Our Pledge
|
||||
|
||||
As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
As members, contributors, and leaders of this community, we pledge to make participation in our project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
|
||||
|
||||
## Why These Standards Are Important
|
||||
|
||||
Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
|
||||
Projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
|
||||
|
||||
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
|
||||
|
||||
|
|
@ -79,7 +79,7 @@ This approach ensures that disruptive behaviors are addressed swiftly and decisi
|
|||
|
||||
## Why Zero Tolerance Is Necessary
|
||||
|
||||
Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
|
||||
Projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
|
||||
|
||||
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
|
||||
|
||||
|
|
|
|||
7
CONTRIBUTOR_LICENSE_AGREEMENT
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Open WebUI Contributor License Agreement
|
||||
|
||||
By submitting my contributions to Open WebUI, I grant Open WebUI full freedom to use my work in any way they choose, under any terms they like, both now and in the future. This approach helps ensure the project remains unified, flexible, and easy to maintain, while empowering Open WebUI to respond quickly to the needs of its users and the wider community.
|
||||
|
||||
Taking part in this process means my work can be seamlessly integrated and combined with others, ensuring longevity and adaptability for everyone who benefits from the Open WebUI project. This collaborative approach strengthens the project’s future and helps guarantee that improvements can always be shared and distributed in the most effective way possible.
|
||||
|
||||
**_To the fullest extent permitted by law, my contributions are provided on an “as is” basis, with no warranties or guarantees of any kind, and I disclaim any liability for any issues or damages arising from their use or incorporation into the project, regardless of the type of legal claim._**
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# Run with
|
||||
# caddy run --envfile ./example.env --config ./Caddyfile.localhost
|
||||
#
|
||||
# This is configured for
|
||||
# - Automatic HTTPS (even for localhost)
|
||||
# - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
|
||||
# - CORS
|
||||
# - HTTP Basic Auth API Tokens (uncomment basicauth section)
|
||||
|
||||
|
||||
# CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
|
||||
(cors-api) {
|
||||
@match-cors-api-preflight method OPTIONS
|
||||
handle @match-cors-api-preflight {
|
||||
header {
|
||||
Access-Control-Allow-Origin "{http.request.header.origin}"
|
||||
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
||||
Access-Control-Allow-Credentials "true"
|
||||
Access-Control-Max-Age "3600"
|
||||
defer
|
||||
}
|
||||
respond "" 204
|
||||
}
|
||||
|
||||
@match-cors-api-request {
|
||||
not {
|
||||
header Origin "{http.request.scheme}://{http.request.host}"
|
||||
}
|
||||
header Origin "{http.request.header.origin}"
|
||||
}
|
||||
handle @match-cors-api-request {
|
||||
header {
|
||||
Access-Control-Allow-Origin "{http.request.header.origin}"
|
||||
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
||||
Access-Control-Allow-Credentials "true"
|
||||
Access-Control-Max-Age "3600"
|
||||
defer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# replace localhost with example.com or whatever
|
||||
localhost {
|
||||
## HTTP Basic Auth
|
||||
## (uncomment to enable)
|
||||
# basicauth {
|
||||
# # see .example.env for how to generate tokens
|
||||
# {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
|
||||
# }
|
||||
|
||||
handle /api/* {
|
||||
# Comment to disable CORS
|
||||
import cors-api
|
||||
|
||||
reverse_proxy localhost:11434
|
||||
}
|
||||
|
||||
# Same-Origin Static Web Server
|
||||
file_server {
|
||||
root ./build/
|
||||
}
|
||||
}
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
ARG USE_CUDA=false
|
||||
ARG USE_OLLAMA=false
|
||||
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
||||
ARG USE_CUDA_VER=cu121
|
||||
ARG USE_CUDA_VER=cu128
|
||||
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
||||
|
|
@ -26,6 +26,9 @@ ARG BUILD_HASH
|
|||
|
||||
WORKDIR /app
|
||||
|
||||
# to store git revision in build
|
||||
RUN apk add --no-cache git
|
||||
|
||||
COPY package.json package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
|
|
@ -132,7 +135,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
|||
# install python dependencies
|
||||
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
||||
|
||||
RUN pip3 install uv && \
|
||||
RUN pip3 install --no-cache-dir uv && \
|
||||
if [ "$USE_CUDA" = "true" ]; then \
|
||||
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
||||
|
|
|
|||
8
LICENSE
|
|
@ -1,4 +1,4 @@
|
|||
Copyright (c) 2023-2025 Timothy Jaeryang Baek
|
||||
Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI)
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
|
|
@ -15,6 +15,12 @@ modification, are permitted provided that the following conditions are met:
|
|||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
4. Notwithstanding any other provision of this License, and as a material condition of the rights granted herein, licensees are strictly prohibited from altering, removing, obscuring, or replacing any "Open WebUI" branding, including but not limited to the name, logo, or any visual, textual, or symbolic identifiers that distinguish the software and its interfaces, in any deployment or distribution, regardless of the number of users, except as explicitly set forth in Clauses 5 and 6 below.
|
||||
|
||||
5. The branding restriction enumerated in Clause 4 shall not apply in the following limited circumstances: (i) deployments or distributions where the total number of end users (defined as individual natural persons with direct access to the application) does not exceed fifty (50) within any rolling thirty (30) day period; (ii) cases in which the licensee is an official contributor to the codebase—with a substantive code change successfully merged into the main branch of the official codebase maintained by the copyright holder—who has obtained specific prior written permission for branding adjustment from the copyright holder; or (iii) where the licensee has obtained a duly executed enterprise license expressly permitting such modification. For all other cases, any removal or alteration of the "Open WebUI" branding shall constitute a material breach of license.
|
||||
|
||||
6. All code, modifications, or derivative works incorporated into this project prior to the incorporation of this branding clause remain licensed under the BSD 3-Clause License, and prior contributors retain all BSD-3 rights therein; if any such contributor requests the removal of their BSD-3-licensed code, the copyright holder will do so, and any replacement code will be licensed under the project's primary license then in effect. By contributing after this clause's adoption, you agree to the project's Contributor License Agreement (CLA) and to these updated terms for all new contributions.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
|
|
|
|||
43
README.md
|
|
@ -7,16 +7,20 @@
|
|||

|
||||

|
||||

|
||||

|
||||
[](https://discord.gg/5rJgQTnV4s)
|
||||
[](https://github.com/sponsors/tjbck)
|
||||
|
||||
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with a **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
|
||||
|
||||
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
|
||||
|
||||

|
||||
|
||||
> [!TIP]
|
||||
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
|
||||
>
|
||||
> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
|
||||
|
||||
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
|
||||
|
||||
## Key Features of Open WebUI ⭐
|
||||
|
||||
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
|
||||
|
|
@ -57,9 +61,36 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
|
|||
|
||||
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
|
||||
|
||||
## 🔗 Also Check Out Open WebUI Community!
|
||||
## Sponsors 🙌
|
||||
|
||||
Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
|
||||
#### Emerald
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://n8n.io/" target="_blank">
|
||||
<img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
N8N • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://warp.dev/open-webui" target="_blank">
|
||||
<img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://warp.dev/open-webui">Warp</a> • The intelligent terminal for developers
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
---
|
||||
|
||||
We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you!
|
||||
|
||||
## How to Install 🚀
|
||||
|
||||
|
|
@ -201,7 +232,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http
|
|||
|
||||
## License 📜
|
||||
|
||||
This project is licensed under the [BSD-3-Clause License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
|
||||
This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 📄
|
||||
|
||||
## Support 💬
|
||||
|
||||
|
|
|
|||
|
|
@ -73,8 +73,15 @@ def serve(
|
|||
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
|
||||
|
||||
import open_webui.main # we need to set environment variables before importing main
|
||||
from open_webui.env import UVICORN_WORKERS # Import the workers setting
|
||||
|
||||
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
|
||||
uvicorn.run(
|
||||
"open_webui.main:app",
|
||||
host=host,
|
||||
port=port,
|
||||
forwarded_allow_ips="*",
|
||||
workers=UVICORN_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -31,6 +31,7 @@ class ERROR_MESSAGES(str, Enum):
|
|||
USERNAME_TAKEN = (
|
||||
"Uh-oh! This username is already registered. Please choose another username."
|
||||
)
|
||||
PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
|
||||
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
|
||||
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
|
||||
|
||||
|
|
|
|||
|
|
@ -65,10 +65,8 @@ except Exception:
|
|||
# LOGGING
|
||||
####################################
|
||||
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
|
@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
|||
|
||||
if "cuda_error" in locals():
|
||||
log.exception(cuda_error)
|
||||
del cuda_error
|
||||
|
||||
log_sources = [
|
||||
"AUDIO",
|
||||
|
|
@ -100,19 +99,19 @@ SRC_LOG_LEVELS = {}
|
|||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
|
||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
|
||||
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
|
||||
|
||||
####################################
|
||||
# ENV (dev,test,prod)
|
||||
|
|
@ -130,7 +129,6 @@ else:
|
|||
except Exception:
|
||||
PACKAGE_DATA = {"version": "0.0.0"}
|
||||
|
||||
|
||||
VERSION = PACKAGE_DATA["version"]
|
||||
|
||||
|
||||
|
|
@ -161,7 +159,6 @@ try:
|
|||
except Exception:
|
||||
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
|
||||
|
||||
|
||||
# Convert markdown content to HTML
|
||||
html_content = markdown.markdown(changelog_content)
|
||||
|
||||
|
|
@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
|
|||
|
||||
changelog_json[version_number] = version_data
|
||||
|
||||
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
####################################
|
||||
|
|
@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
|
|||
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI_BUILD_HASH
|
||||
####################################
|
||||
|
|
@ -244,7 +239,6 @@ if FROM_INIT_PY:
|
|||
|
||||
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
|
||||
|
||||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
|
||||
|
||||
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
|
||||
|
|
@ -256,7 +250,6 @@ if FROM_INIT_PY:
|
|||
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
|
||||
).resolve()
|
||||
|
||||
|
||||
####################################
|
||||
# Database
|
||||
####################################
|
||||
|
|
@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
|
|||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
ENABLE_REALTIME_CHAT_SAVE = (
|
||||
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
|
||||
)
|
||||
|
|
@ -330,7 +322,23 @@ ENABLE_REALTIME_CHAT_SAVE = (
|
|||
# REDIS
|
||||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "")
|
||||
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
|
||||
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
|
||||
|
||||
####################################
|
||||
# UVICORN WORKERS
|
||||
####################################
|
||||
|
||||
# Number of uvicorn worker processes for handling requests
|
||||
UVICORN_WORKERS = os.environ.get("UVICORN_WORKERS", "1")
|
||||
try:
|
||||
UVICORN_WORKERS = int(UVICORN_WORKERS)
|
||||
if UVICORN_WORKERS < 1:
|
||||
UVICORN_WORKERS = 1
|
||||
except ValueError:
|
||||
UVICORN_WORKERS = 1
|
||||
log.info(f"Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}")
|
||||
|
||||
####################################
|
||||
# WEBUI_AUTH (Required for security)
|
||||
|
|
@ -341,11 +349,19 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
|||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||
)
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
|
||||
)
|
||||
|
||||
|
||||
BYPASS_MODEL_ACCESS_CONTROL = (
|
||||
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
||||
)
|
||||
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get(
|
||||
"WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None
|
||||
)
|
||||
|
||||
####################################
|
||||
# WEBUI_SECRET_KEY
|
||||
####################################
|
||||
|
|
@ -385,6 +401,11 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
|||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
|
||||
|
||||
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
|
||||
|
||||
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
@@ -396,19 +417,88 @@ else:
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300

AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
AIOHTTP_CLIENT_SESSION_SSL = (
    os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true"
)

if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
    "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
    os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
)

if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
            AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10


AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get(
    "AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10"
)

if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "":
    AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(
            AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
        )
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
        AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10


AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
    os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
)


####################################
# SENTENCE TRANSFORMERS
####################################


SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
    SENTENCE_TRANSFORMERS_BACKEND = "torch"


SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
    "SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
    SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
    try:
        SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
            SENTENCE_TRANSFORMERS_MODEL_KWARGS
        )
    except Exception:
        SENTENCE_TRANSFORMERS_MODEL_KWARGS = None


SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"


SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
    try:
        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
            SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
        )
    except Exception:
        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None

####################################
# OFFLINE_MODE

@@ -418,3 +508,56 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"

if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"


####################################
# AUDIT LOGGING
####################################
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
    MAX_BODY_LOG_SIZE = 2048

# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
    ","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]


####################################
# OPENTELEMETRY
####################################

ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
    "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
    "OTEL_RESOURCE_ATTRIBUTES", ""
)  # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
    "OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()

####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################

PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()


####################################
# PROGRESSIVE WEB APP OPTIONS
####################################

EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
@@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
import asyncio

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator

@@ -27,7 +28,10 @@ from open_webui.socket.main import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
    load_function_module_by_id,
    get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access

@@ -52,12 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])


def get_function_module_by_id(request: Request, pipe_id: str):
    # Check if function is already loaded
    if pipe_id not in request.app.state.FUNCTIONS:
        function_module, _, _ = load_function_module_by_id(pipe_id)
        request.app.state.FUNCTIONS[pipe_id] = function_module
    else:
        function_module = request.app.state.FUNCTIONS[pipe_id]
    function_module, _, _ = get_function_module_from_cache(request, pipe_id)

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(pipe_id)

@@ -76,11 +75,13 @@ async def get_function_models(request):
    if hasattr(function_module, "pipes"):
        sub_pipes = []

        # Check if pipes is a function or a list
        # Handle pipes being a list, sync function, or async function
        try:
            if callable(function_module.pipes):
                sub_pipes = function_module.pipes()
                if asyncio.iscoroutinefunction(function_module.pipes):
                    sub_pipes = await function_module.pipes()
                else:
                    sub_pipes = function_module.pipes()
            else:
                sub_pipes = function_module.pipes
        except Exception as e:

@@ -220,6 +221,9 @@ async def generate_function_chat_completion(
    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__chat_id__": metadata.get("chat_id", None),
        "__session_id__": metadata.get("session_id", None),
        "__message_id__": metadata.get("message_id", None),
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,

@@ -249,8 +253,13 @@ async def generate_function_chat_completion(
        form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)

        if params:
            system = params.pop("system", None)
            form_data = apply_model_params_to_body_openai(params, form_data)
            form_data = apply_model_system_prompt_to_body(
                system, form_data, metadata, user
            )

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)
@@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):


def register_connection(db_url):
    db = connect(db_url, unquote_password=True)
    db = connect(db_url, unquote_user=True, unquote_password=True)
    if isinstance(db, PostgresqlDatabase):
        # Enable autoconnect for SQLite databases, managed by Peewee
        db.autoconnect = True

@@ -51,7 +51,7 @@ def register_connection(db_url):
        log.info("Connected to PostgreSQL database")

        # Get the connection details
        connection = parse(db_url, unquote_password=True)
        connection = parse(db_url, unquote_user=True, unquote_password=True)

        # Use our custom database class that supports reconnection
        db = ReconnectingPostgresqlDatabase(**connection)
File diff suppressed because it is too large
@@ -0,0 +1,33 @@
"""Add note table

Revision ID: 9f0c9cd09105
Revises: 3781e22d8b01
Create Date: 2025-05-03 03:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "9f0c9cd09105"
down_revision = "3781e22d8b01"
branch_labels = None
depends_on = None


def upgrade():
    op.create_table(
        "note",
        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
        sa.Column("user_id", sa.Text(), nullable=True),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("meta", sa.JSON(), nullable=True),
        sa.Column("access_control", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=True),
        sa.Column("updated_at", sa.BigInteger(), nullable=True),
    )


def downgrade():
    op.drop_table("note")
@@ -129,12 +129,16 @@ class AuthsTable:

    def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
        log.info(f"authenticate_user: {email}")

        user = Users.get_user_by_email(email)
        if not user:
            return None

        try:
            with get_db() as db:
                auth = db.query(Auth).filter_by(email=email, active=True).first()
                auth = db.query(Auth).filter_by(id=user.id, active=True).first()
                if auth:
                    if verify_password(password, auth.password):
                        user = Users.get_user_by_id(auth.id)
                        return user
                    else:
                        return None

@@ -155,8 +159,8 @@ class AuthsTable:
        except Exception:
            return False

    def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
        log.info(f"authenticate_user_by_trusted_header: {email}")
    def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
        log.info(f"authenticate_user_by_email: {email}")
        try:
            with get_db() as db:
                auth = db.query(Auth).filter_by(email=email, active=True).first()
@@ -1,3 +1,4 @@
import logging
import json
import time
import uuid

@@ -5,7 +6,7 @@ from typing import Optional

from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags

from open_webui.env import SRC_LOG_LEVELS

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON

@@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
# Chat DB Schema
####################

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


class Chat(Base):
    __tablename__ = "chat"

@@ -373,22 +377,47 @@ class ChatTable:
            return False

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
        self,
        user_id: str,
        filter: Optional[dict] = None,
        skip: int = 0,
        limit: int = 50,
    ) -> list[ChatModel]:

        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                # .limit(limit).offset(skip)
                .all()
            )
            query = db.query(Chat).filter_by(user_id=user_id, archived=True)

            if filter:
                query_key = filter.get("query")
                if query_key:
                    query = query.filter(Chat.title.ilike(f"%{query_key}%"))

                order_by = filter.get("order_by")
                direction = filter.get("direction")

                if order_by and direction and getattr(Chat, order_by):
                    if direction.lower() == "asc":
                        query = query.order_by(getattr(Chat, order_by).asc())
                    elif direction.lower() == "desc":
                        query = query.order_by(getattr(Chat, order_by).desc())
                    else:
                        raise ValueError("Invalid direction for ordering")
                else:
                    query = query.order_by(Chat.updated_at.desc())

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            all_chats = query.all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        filter: Optional[dict] = None,
        skip: int = 0,
        limit: int = 50,
    ) -> list[ChatModel]:

@@ -397,7 +426,23 @@ class ChatTable:
            if not include_archived:
                query = query.filter_by(archived=False)

            query = query.order_by(Chat.updated_at.desc())
            if filter:
                query_key = filter.get("query")
                if query_key:
                    query = query.filter(Chat.title.ilike(f"%{query_key}%"))

                order_by = filter.get("order_by")
                direction = filter.get("direction")

                if order_by and direction and getattr(Chat, order_by):
                    if direction.lower() == "asc":
                        query = query.order_by(getattr(Chat, order_by).asc())
                    elif direction.lower() == "desc":
                        query = query.order_by(getattr(Chat, order_by).desc())
                    else:
                        raise ValueError("Invalid direction for ordering")
                else:
                    query = query.order_by(Chat.updated_at.desc())

            if skip:
                query = query.offset(skip)

@@ -432,7 +477,7 @@ class ChatTable:

            all_chats = query.all()

            # result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
            # result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
            return [
                ChatTitleIdResponse.model_validate(
                    {

@@ -538,7 +583,9 @@ class ChatTable:
        search_text = search_text.lower().strip()

        if not search_text:
            return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
            return self.get_chat_list_by_user_id(
                user_id, include_archived, filter={}, skip=skip, limit=limit
            )

        search_text_words = search_text.split(" ")

@@ -670,7 +717,7 @@ class ChatTable:
            # Perform pagination at the SQL level
            all_chats = query.offset(skip).limit(limit).all()

            print(len(all_chats))
            log.info(f"The number of chats: {len(all_chats)}")

            # Validate and return chats
            return [ChatModel.model_validate(chat) for chat in all_chats]

@@ -731,7 +778,7 @@ class ChatTable:
            query = db.query(Chat).filter_by(user_id=user_id)
            tag_id = tag_name.replace(" ", "_").lower()

            print(db.bind.dialect.name)
            log.info(f"DB dialect name: {db.bind.dialect.name}")
            if db.bind.dialect.name == "sqlite":
                # SQLite JSON1 querying for tags within the meta JSON field
                query = query.filter(

@@ -752,7 +799,7 @@ class ChatTable:
            )

            all_chats = query.all()
            print("all_chats", all_chats)
            log.debug(f"all_chats: {all_chats}")
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def add_chat_tag_by_id_and_user_id_and_tag_name(

@@ -810,7 +857,7 @@ class ChatTable:
            count = query.count()

            # Debugging output for inspection
            print(f"Count of chats for tag '{tag_name}':", count)
            log.info(f"Count of chats for tag '{tag_name}': {count}")

            return count
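A minimal usage sketch for the new filter parameter on ChatTable.get_chat_list_by_user_id introduced above; not part of the diff — the user id and filter values are placeholders, and an initialized Open WebUI database is assumed:

from open_webui.models.chats import Chats

# Search a user's unarchived chats by title, newest first (placeholder user id).
chats = Chats.get_chat_list_by_user_id(
    user_id="user-123",
    include_archived=False,
    filter={"query": "report", "order_by": "updated_at", "direction": "desc"},
    skip=0,
    limit=20,
)
for chat in chats:
    print(chat.id, chat.title)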
@@ -118,7 +118,7 @@ class FeedbackTable:
            else:
                return None
        except Exception as e:
            print(e)
            log.exception(f"Error creating a new feedback: {e}")
            return None

    def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
@@ -119,7 +119,7 @@ class FilesTable:
            else:
                return None
        except Exception as e:
            print(f"Error creating tool: {e}")
            log.exception(f"Error inserting a new file: {e}")
            return None

    def get_file_by_id(self, id: str) -> Optional[FileModel]:
@@ -9,6 +9,8 @@ from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import get_permissions


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

@@ -82,7 +84,7 @@ class FolderTable:
            else:
                return None
        except Exception as e:
            print(e)
            log.exception(f"Error inserting a new folder: {e}")
            return None

    def get_folder_by_id_and_user_id(

@@ -234,15 +236,18 @@ class FolderTable:
            log.error(f"update_folder: {e}")
            return

    def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
    def delete_folder_by_id_and_user_id(
        self, id: str, user_id: str, delete_chats=True
    ) -> bool:
        try:
            with get_db() as db:
                folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
                if not folder:
                    return False

                # Delete all chats in the folder
                Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
                if delete_chats:
                    # Delete all chats in the folder
                    Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)

                # Delete all children folders
                def delete_children(folder):

@@ -250,9 +255,11 @@ class FolderTable:
                        folder.id, user_id
                    )
                    for folder_child in folder_children:
                        Chats.delete_chats_by_user_id_and_folder_id(
                            user_id, folder_child.id
                        )
                        if delete_chats:
                            Chats.delete_chats_by_user_id_and_folder_id(
                                user_id, folder_child.id
                            )

                        delete_children(folder_child)

                        folder = db.query(Folder).filter_by(id=folder_child.id).first()
@@ -105,9 +105,57 @@ class FunctionsTable:
            else:
                return None
        except Exception as e:
            print(f"Error creating tool: {e}")
            log.exception(f"Error creating a new function: {e}")
            return None

    def sync_functions(
        self, user_id: str, functions: list[FunctionModel]
    ) -> list[FunctionModel]:
        # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
        try:
            with get_db() as db:
                # Get existing functions
                existing_functions = db.query(Function).all()
                existing_ids = {func.id for func in existing_functions}

                # Prepare a set of new function IDs
                new_function_ids = {func.id for func in functions}

                # Update or insert functions
                for func in functions:
                    if func.id in existing_ids:
                        db.query(Function).filter_by(id=func.id).update(
                            {
                                **func.model_dump(),
                                "user_id": user_id,
                                "updated_at": int(time.time()),
                            }
                        )
                    else:
                        new_func = Function(
                            **{
                                **func.model_dump(),
                                "user_id": user_id,
                                "updated_at": int(time.time()),
                            }
                        )
                        db.add(new_func)

                # Remove functions that are no longer present
                for func in existing_functions:
                    if func.id not in new_function_ids:
                        db.delete(func)

                db.commit()

                return [
                    FunctionModel.model_validate(func)
                    for func in db.query(Function).all()
                ]
        except Exception as e:
            log.exception(f"Error syncing functions for user {user_id}: {e}")
            return []

    def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
        try:
            with get_db() as db:

@@ -170,7 +218,7 @@ class FunctionsTable:
            function = db.get(Function, id)
            return function.valves if function.valves else {}
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(f"Error getting function valves by id {id}: {e}")
            return None

    def update_function_valves_by_id(

@@ -202,7 +250,9 @@ class FunctionsTable:

            return user_settings["functions"]["valves"].get(id, {})
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(
                f"Error getting user values by id {id} and user id {user_id}: {e}"
            )
            return None

    def update_user_valves_by_id_and_user_id(

@@ -225,7 +275,9 @@ class FunctionsTable:

            return user_settings["functions"]["valves"][id]
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(
                f"Error updating user valves by id {id} and user_id {user_id}: {e}"
            )
            return None

    def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
@@ -207,5 +207,43 @@ class GroupTable:
        except Exception:
            return False

    def sync_user_groups_by_group_names(
        self, user_id: str, group_names: list[str]
    ) -> bool:
        with get_db() as db:
            try:
                groups = db.query(Group).filter(Group.name.in_(group_names)).all()
                group_ids = [group.id for group in groups]

                # Remove user from groups not in the new list
                existing_groups = self.get_groups_by_member_id(user_id)

                for group in existing_groups:
                    if group.id not in group_ids:
                        group.user_ids.remove(user_id)
                        db.query(Group).filter_by(id=group.id).update(
                            {
                                "user_ids": group.user_ids,
                                "updated_at": int(time.time()),
                            }
                        )

                # Add user to new groups
                for group in groups:
                    if user_id not in group.user_ids:
                        group.user_ids.append(user_id)
                        db.query(Group).filter_by(id=group.id).update(
                            {
                                "user_ids": group.user_ids,
                                "updated_at": int(time.time()),
                            }
                        )

                db.commit()
                return True
            except Exception as e:
                log.exception(e)
                return False


Groups = GroupTable()
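A short sketch of how the new GroupTable.sync_user_groups_by_group_names might be called, e.g. when mapping external claims to groups; not part of the diff — the id and names are placeholders:

from open_webui.models.groups import Groups

# Make the user's membership match exactly these group names (placeholders).
ok = Groups.sync_user_groups_by_group_names("user-123", ["engineering", "beta-testers"])
print("groups synced" if ok else "sync failed")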
@@ -63,14 +63,15 @@ class MemoriesTable:
        else:
            return None

    def update_memory_by_id(
    def update_memory_by_id_and_user_id(
        self,
        id: str,
        user_id: str,
        content: str,
    ) -> Optional[MemoryModel]:
        with get_db() as db:
            try:
                db.query(Memory).filter_by(id=id).update(
                db.query(Memory).filter_by(id=id, user_id=user_id).update(
                    {"content": content, "updated_at": int(time.time())}
                )
                db.commit()
5
backend/open_webui/models/models.py
Normal file → Executable file

@@ -166,7 +166,7 @@ class ModelsTable:
            else:
                return None
        except Exception as e:
            print(e)
            log.exception(f"Failed to insert a new model: {e}")
            return None

    def get_all_models(self) -> list[ModelModel]:

@@ -246,8 +246,7 @@ class ModelsTable:
                db.refresh(model)
                return ModelModel.model_validate(model)
        except Exception as e:
            print(e)

            log.exception(f"Failed to update the model by id {id}: {e}")
            return None

    def delete_model_by_id(self, id: str) -> bool:
135
backend/open_webui/models/notes.py
Normal file

@@ -0,0 +1,135 @@
import json
import time
import uuid
from typing import Optional

from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from open_webui.models.users import Users, UserResponse


from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists

####################
# Note DB Schema
####################


class Note(Base):
    __tablename__ = "note"

    id = Column(Text, primary_key=True)
    user_id = Column(Text)

    title = Column(Text)
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    access_control = Column(JSON, nullable=True)

    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class NoteModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str

    title: str
    data: Optional[dict] = None
    meta: Optional[dict] = None

    access_control: Optional[dict] = None

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch


####################
# Forms
####################


class NoteForm(BaseModel):
    title: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    access_control: Optional[dict] = None


class NoteUserResponse(NoteModel):
    user: Optional[UserResponse] = None


class NoteTable:
    def insert_new_note(
        self,
        form_data: NoteForm,
        user_id: str,
    ) -> Optional[NoteModel]:
        with get_db() as db:
            note = NoteModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    **form_data.model_dump(),
                    "created_at": int(time.time_ns()),
                    "updated_at": int(time.time_ns()),
                }
            )

            new_note = Note(**note.model_dump())

            db.add(new_note)
            db.commit()
            return note

    def get_notes(self) -> list[NoteModel]:
        with get_db() as db:
            notes = db.query(Note).order_by(Note.updated_at.desc()).all()
            return [NoteModel.model_validate(note) for note in notes]

    def get_notes_by_user_id(
        self, user_id: str, permission: str = "write"
    ) -> list[NoteModel]:
        notes = self.get_notes()
        return [
            note
            for note in notes
            if note.user_id == user_id
            or has_access(user_id, permission, note.access_control)
        ]

    def get_note_by_id(self, id: str) -> Optional[NoteModel]:
        with get_db() as db:
            note = db.query(Note).filter(Note.id == id).first()
            return NoteModel.model_validate(note) if note else None

    def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
        with get_db() as db:
            note = db.query(Note).filter(Note.id == id).first()
            if not note:
                return None

            note.title = form_data.title
            note.data = form_data.data
            note.meta = form_data.meta
            note.access_control = form_data.access_control
            note.updated_at = int(time.time_ns())

            db.commit()
            return NoteModel.model_validate(note) if note else None

    def delete_note_by_id(self, id: str):
        with get_db() as db:
            db.query(Note).filter(Note.id == id).delete()
            db.commit()
            return True


Notes = NoteTable()
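A minimal usage sketch for the new Notes table defined above; not part of the diff — the user id is a placeholder and an initialized database (with the migration applied) is assumed:

from open_webui.models.notes import Notes, NoteForm

# Create, update, list, and delete a note (placeholder user id).
note = Notes.insert_new_note(NoteForm(title="Meeting notes"), user_id="user-123")
Notes.update_note_by_id(note.id, NoteForm(title="Meeting notes (edited)"))
print([n.title for n in Notes.get_notes_by_user_id("user-123", permission="read")])
Notes.delete_note_by_id(note.id)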
@@ -61,7 +61,7 @@ class TagTable:
            else:
                return None
        except Exception as e:
            print(e)
            log.exception(f"Error inserting a new tag: {e}")
            return None

    def get_tag_by_name_and_user_id(
@@ -131,7 +131,7 @@ class ToolsTable:
            else:
                return None
        except Exception as e:
            print(f"Error creating tool: {e}")
            log.exception(f"Error creating a new tool: {e}")
            return None

    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:

@@ -175,7 +175,7 @@ class ToolsTable:
            tool = db.get(Tool, id)
            return tool.valves if tool.valves else {}
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(f"Error getting tool valves by id {id}: {e}")
            return None

    def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:

@@ -204,7 +204,9 @@ class ToolsTable:

            return user_settings["tools"]["valves"].get(id, {})
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(
                f"Error getting user values by id {id} and user_id {user_id}: {e}"
            )
            return None

    def update_user_valves_by_id_and_user_id(

@@ -227,7 +229,9 @@ class ToolsTable:

            return user_settings["tools"]["valves"][id]
        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception(
                f"Error updating user valves by id {id} and user_id {user_id}: {e}"
            )
            return None

    def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
@ -10,6 +10,8 @@ from open_webui.models.groups import Groups
|
|||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import or_
|
||||
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
|
|
@ -67,6 +69,11 @@ class UserModel(BaseModel):
|
|||
####################
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: list[UserModel]
|
||||
total: int
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
|
@ -160,11 +167,63 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
def get_users(
|
||||
self, skip: Optional[int] = None, limit: Optional[int] = None
|
||||
) -> list[UserModel]:
|
||||
self,
|
||||
filter: Optional[dict] = None,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> UserListResponse:
|
||||
with get_db() as db:
|
||||
query = db.query(User)
|
||||
|
||||
query = db.query(User).order_by(User.created_at.desc())
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
)
|
||||
)
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.name.asc())
|
||||
else:
|
||||
query = query.order_by(User.name.desc())
|
||||
elif order_by == "email":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.email.asc())
|
||||
else:
|
||||
query = query.order_by(User.email.desc())
|
||||
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
elif order_by == "last_active_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.last_active_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.last_active_at.desc())
|
||||
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.updated_at.desc())
|
||||
elif order_by == "role":
|
||||
if direction == "asc":
|
||||
query = query.order_by(User.role.asc())
|
||||
else:
|
||||
query = query.order_by(User.role.desc())
|
||||
|
||||
else:
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
|
|
@ -172,8 +231,10 @@ class UsersTable:
|
|||
query = query.limit(limit)
|
||||
|
||||
users = query.all()
|
||||
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
return {
|
||||
"users": [UserModel.model_validate(user) for user in users],
|
||||
"total": db.query(User).count(),
|
||||
}
|
||||
|
||||
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
|
||||
with get_db() as db:
|
||||
|
|
@ -330,5 +391,13 @@ class UsersTable:
|
|||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [user.id for user in users]
|
||||
|
||||
def get_super_admin_user(self) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(role="admin").first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
Users = UsersTable()
|
||||
|
|
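A sketch of the reworked UsersTable.get_users, which now returns a dict with "users" and "total" rather than a bare list; not part of the diff — the filter values are placeholders:

from open_webui.models.users import Users

result = Users.get_users(
    filter={"query": "alice", "order_by": "last_active_at", "direction": "desc"},
    skip=0,
    limit=10,
)
print(result["total"], [u.email for u in result["users"]])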
251
backend/open_webui/retrieval/loaders/datalab_marker.py
Normal file

@@ -0,0 +1,251 @@
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status

log = logging.getLogger(__name__)


class DatalabMarkerLoader:
    def __init__(
        self,
        file_path: str,
        api_key: str,
        langs: Optional[str] = None,
        use_llm: bool = False,
        skip_cache: bool = False,
        force_ocr: bool = False,
        paginate: bool = False,
        strip_existing_ocr: bool = False,
        disable_image_extraction: bool = False,
        output_format: str = None,
    ):
        self.file_path = file_path
        self.api_key = api_key
        self.langs = langs
        self.use_llm = use_llm
        self.skip_cache = skip_cache
        self.force_ocr = force_ocr
        self.paginate = paginate
        self.strip_existing_ocr = strip_existing_ocr
        self.disable_image_extraction = disable_image_extraction
        self.output_format = output_format

    def _get_mime_type(self, filename: str) -> str:
        ext = filename.rsplit(".", 1)[-1].lower()
        mime_map = {
            "pdf": "application/pdf",
            "xls": "application/vnd.ms-excel",
            "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "ods": "application/vnd.oasis.opendocument.spreadsheet",
            "doc": "application/msword",
            "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "odt": "application/vnd.oasis.opendocument.text",
            "ppt": "application/vnd.ms-powerpoint",
            "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            "odp": "application/vnd.oasis.opendocument.presentation",
            "html": "text/html",
            "epub": "application/epub+zip",
            "png": "image/png",
            "jpeg": "image/jpeg",
            "jpg": "image/jpeg",
            "webp": "image/webp",
            "gif": "image/gif",
            "tiff": "image/tiff",
        }
        return mime_map.get(ext, "application/octet-stream")

    def check_marker_request_status(self, request_id: str) -> dict:
        url = f"https://www.datalab.to/api/v1/marker/{request_id}"
        headers = {"X-Api-Key": self.api_key}
        try:
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            result = response.json()
            log.info(f"Marker API status check for request {request_id}: {result}")
            return result
        except requests.HTTPError as e:
            log.error(f"Error checking Marker request status: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Failed to check Marker request: {e}",
            )
        except ValueError as e:
            log.error(f"Invalid JSON checking Marker request: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
            )

    def load(self) -> List[Document]:
        url = "https://www.datalab.to/api/v1/marker"
        filename = os.path.basename(self.file_path)
        mime_type = self._get_mime_type(filename)
        headers = {"X-Api-Key": self.api_key}

        form_data = {
            "langs": self.langs,
            "use_llm": str(self.use_llm).lower(),
            "skip_cache": str(self.skip_cache).lower(),
            "force_ocr": str(self.force_ocr).lower(),
            "paginate": str(self.paginate).lower(),
            "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
            "disable_image_extraction": str(self.disable_image_extraction).lower(),
            "output_format": self.output_format,
        }

        log.info(
            f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
        )

        try:
            with open(self.file_path, "rb") as f:
                files = {"file": (filename, f, mime_type)}
                response = requests.post(
                    url, data=form_data, files=files, headers=headers
                )
                response.raise_for_status()
                result = response.json()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {e}",
            )
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
            )
        except Exception as e:
            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

        if not result.get("success"):
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
            )

        check_url = result.get("request_check_url")
        request_id = result.get("request_id")
        if not check_url:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
            )

        for _ in range(300):  # Up to 10 minutes
            time.sleep(2)
            try:
                poll_response = requests.get(check_url, headers=headers)
                poll_response.raise_for_status()
                poll_result = poll_response.json()
            except (requests.HTTPError, ValueError) as e:
                raw_body = poll_response.text
                log.error(f"Polling error: {e}, response body: {raw_body}")
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
                )

            status_val = poll_result.get("status")
            success_val = poll_result.get("success")

            if status_val == "complete":
                summary = {
                    k: poll_result.get(k)
                    for k in (
                        "status",
                        "output_format",
                        "success",
                        "error",
                        "page_count",
                        "total_cost",
                    )
                }
                log.info(
                    f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
                )
                break

            if status_val == "failed" or success_val is False:
                log.error(
                    f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
                )
                error_msg = (
                    poll_result.get("error")
                    or "Marker returned failure without error message"
                )
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"Marker processing failed: {error_msg}",
                )
        else:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
            )

        if not poll_result.get("success", False):
            error_msg = poll_result.get("error") or "Unknown processing error"
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Final processing failed: {error_msg}",
            )

        content_key = self.output_format.lower()
        raw_content = poll_result.get(content_key)

        if content_key == "json":
            full_text = json.dumps(raw_content, indent=2)
        elif content_key in {"markdown", "html"}:
            full_text = str(raw_content).strip()
        else:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported output format: {self.output_format}",
            )

        if not full_text:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Datalab Marker returned empty content",
            )

        marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
        os.makedirs(marker_output_dir, exist_ok=True)

        file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
        file_ext = file_ext_map.get(content_key, "txt")
        output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
        output_path = os.path.join(marker_output_dir, output_filename)

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(full_text)
            log.info(f"Saved Marker output to: {output_path}")
        except Exception as e:
            log.warning(f"Failed to write marker output to disk: {e}")

        metadata = {
            "source": filename,
            "output_format": poll_result.get("output_format", self.output_format),
            "page_count": poll_result.get("page_count", 0),
            "processed_with_llm": self.use_llm,
            "request_id": request_id or "",
        }

        images = poll_result.get("images", {})
        if images:
            metadata["image_count"] = len(images)
            metadata["images"] = json.dumps(list(images.keys()))

        for k, v in metadata.items():
            if isinstance(v, (dict, list)):
                metadata[k] = json.dumps(v)
            elif v is None:
                metadata[k] = ""

        return [Document(page_content=full_text, metadata=metadata)]
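A minimal sketch of driving the DatalabMarkerLoader above directly; not part of the diff — the API key and file path are placeholders, and a reachable Datalab account is assumed:

from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader

loader = DatalabMarkerLoader(
    file_path="/tmp/sample.pdf",  # placeholder path
    api_key="DATALAB_API_KEY",    # placeholder key
    langs="en",
    use_llm=False,
    output_format="markdown",
)
docs = loader.load()  # blocks while polling the Marker API
print(docs[0].metadata.get("page_count"), docs[0].page_content[:200])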
58
backend/open_webui/retrieval/loaders/external_document.py
Normal file

@@ -0,0 +1,58 @@
import requests
import logging
from typing import Iterator, List, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ExternalDocumentLoader(BaseLoader):
    def __init__(
        self,
        file_path,
        url: str,
        api_key: str,
        mime_type=None,
        **kwargs,
    ) -> None:
        self.url = url
        self.api_key = api_key

        self.file_path = file_path
        self.mime_type = mime_type

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

        headers = {}
        if self.mime_type is not None:
            headers["Content-Type"] = self.mime_type

        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"

        url = self.url
        if url.endswith("/"):
            url = url[:-1]

        r = requests.put(f"{url}/process", data=data, headers=headers)

        if r.ok:
            res = r.json()

            if res:
                return [
                    Document(
                        page_content=res.get("page_content"),
                        metadata=res.get("metadata"),
                    )
                ]
            else:
                raise Exception("Error loading document: No content returned")
        else:
            raise Exception(f"Error loading document: {r.status_code} {r.text}")
53
backend/open_webui/retrieval/loaders/external_web.py
Normal file

@@ -0,0 +1,53 @@
import requests
import logging
from typing import Iterator, List, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ExternalWebLoader(BaseLoader):
    def __init__(
        self,
        web_paths: Union[str, List[str]],
        external_url: str,
        external_api_key: str,
        continue_on_failure: bool = True,
        **kwargs,
    ) -> None:
        self.external_url = external_url
        self.external_api_key = external_api_key
        self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
        self.continue_on_failure = continue_on_failure

    def lazy_load(self) -> Iterator[Document]:
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            urls = self.urls[i : i + batch_size]
            try:
                response = requests.post(
                    self.external_url,
                    headers={
                        "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
                        "Authorization": f"Bearer {self.external_api_key}",
                    },
                    json={
                        "urls": urls,
                    },
                )
                response.raise_for_status()
                results = response.json()
                for result in results:
                    yield Document(
                        page_content=result.get("page_content", ""),
                        metadata=result.get("metadata", {}),
                    )
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {urls}: {e}")
                else:
                    raise e
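A short sketch of the ExternalWebLoader above; not part of the diff — the endpoint and key are placeholders for whatever external extraction service is configured:

from open_webui.retrieval.loaders.external_web import ExternalWebLoader

loader = ExternalWebLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    external_url="https://loader.internal/extract",  # placeholder endpoint
    external_api_key="EXTERNAL_API_KEY",             # placeholder key
)
for doc in loader.lazy_load():
    print(doc.metadata, len(doc.page_content))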
@@ -4,6 +4,7 @@ import ftfy
import sys

from langchain_community.document_loaders import (
    AzureAIDocumentIntelligenceLoader,
    BSHTMLLoader,
    CSVLoader,
    Docx2txtLoader,

@@ -19,6 +20,13 @@ from langchain_community.document_loaders import (
    YoutubeLoader,
)
from langchain_core.documents import Document

from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader

from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader


from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)

@@ -76,15 +84,18 @@ known_source_ext = [
    "jsx",
    "hs",
    "lhs",
    "json",
]


class TikaLoader:
    def __init__(self, url, file_path, mime_type=None):
    def __init__(self, url, file_path, mime_type=None, extract_images=None):
        self.url = url
        self.file_path = file_path
        self.mime_type = mime_type

        self.extract_images = extract_images

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

@@ -94,6 +105,9 @@ class TikaLoader:
        else:
            headers = {}

        if self.extract_images == True:
            headers["X-Tika-PDFextractInlineImages"] = "true"

        endpoint = self.url
        if not endpoint.endswith("/"):
            endpoint += "/"

@@ -103,7 +117,7 @@ class TikaLoader:

        if r.ok:
            raw_metadata = r.json()
            text = raw_metadata.get("X-TIKA:content", "<No text content found>")
            text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()

            if "Content-Type" in raw_metadata:
                headers["Content-Type"] = raw_metadata["Content-Type"]

@@ -115,6 +129,68 @@ class TikaLoader:
            raise Exception(f"Error calling Tika: {r.reason}")


class DoclingLoader:
    def __init__(self, url, file_path=None, mime_type=None, params=None):
        self.url = url.rstrip("/")
        self.file_path = file_path
        self.mime_type = mime_type

        self.params = params or {}

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            files = {
                "files": (
                    self.file_path,
                    f,
                    self.mime_type or "application/octet-stream",
                )
            }

            params = {
                "image_export_mode": "placeholder",
                "table_mode": "accurate",
            }

            if self.params:
                if self.params.get("do_picture_classification"):
                    params["do_picture_classification"] = self.params.get(
                        "do_picture_classification"
                    )

                if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
                    params["ocr_engine"] = self.params.get("ocr_engine")
                    params["ocr_lang"] = [
                        lang.strip()
                        for lang in self.params.get("ocr_lang").split(",")
                        if lang.strip()
                    ]

            endpoint = f"{self.url}/v1alpha/convert/file"
            r = requests.post(endpoint, files=files, data=params)

        if r.ok:
            result = r.json()
            document_data = result.get("document", {})
            text = document_data.get("md_content", "<No text content found>")

            metadata = {"Content-Type": self.mime_type} if self.mime_type else {}

            log.debug("Docling extracted text: %s", text)

            return [Document(page_content=text, metadata=metadata)]
        else:
            error_msg = f"Error calling Docling API: {r.reason}"
            if r.text:
                try:
                    error_data = r.json()
                    if "detail" in error_data:
                        error_msg += f" - {error_data['detail']}"
                except Exception:
                    error_msg += f" - {r.text}"
            raise Exception(f"Error calling Docling: {error_msg}")


class Loader:
    def __init__(self, engine: str = "", **kwargs):
        self.engine = engine

@@ -133,27 +209,140 @@ class Loader:
            for doc in docs
        ]

    def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
        return file_ext in known_source_ext or (
            file_content_type and file_content_type.find("text/") >= 0
        )

    def _get_loader(self, filename: str, file_content_type: str, file_path: str):
        file_ext = filename.split(".")[-1].lower()

        if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
            if file_ext in known_source_ext or (
                file_content_type and file_content_type.find("text/") >= 0
            ):
        if (
            self.engine == "external"
            and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
            and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
        ):
            loader = ExternalDocumentLoader(
                file_path=file_path,
                url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
                api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
                mime_type=file_content_type,
            )
        elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
            if self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TikaLoader(
                    url=self.kwargs.get("TIKA_SERVER_URL"),
                    file_path=file_path,
                    mime_type=file_content_type,
                    extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
                )
        elif (
            self.engine == "datalab_marker"
            and self.kwargs.get("DATALAB_MARKER_API_KEY")
            and file_ext
            in [
                "pdf",
                "xls",
                "xlsx",
                "ods",
                "doc",
                "docx",
                "odt",
                "ppt",
                "pptx",
                "odp",
                "html",
                "epub",
                "png",
                "jpeg",
                "jpg",
                "webp",
                "gif",
                "tiff",
            ]
        ):
            loader = DatalabMarkerLoader(
                file_path=file_path,
                api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
                langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
                use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
                skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
                force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
                paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
                strip_existing_ocr=self.kwargs.get(
                    "DATALAB_MARKER_STRIP_EXISTING_OCR", False
                ),
                disable_image_extraction=self.kwargs.get(
                    "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
                ),
                output_format=self.kwargs.get(
                    "DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
                ),
            )
        elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
            if self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = DoclingLoader(
                    url=self.kwargs.get("DOCLING_SERVER_URL"),
                    file_path=file_path,
                    mime_type=file_content_type,
                    params={
                        "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
                        "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
                        "do_picture_classification": self.kwargs.get(
                            "DOCLING_DO_PICTURE_DESCRIPTION"
                        ),
                    },
                )
        elif (
            self.engine == "document_intelligence"
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
            and (
                file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
                or file_content_type
                in [
                    "application/vnd.ms-excel",
                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    "application/vnd.ms-powerpoint",
                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                ]
            )
        ):
            loader = AzureAIDocumentIntelligenceLoader(
                file_path=file_path,
                api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
            )
        elif (
            self.engine == "mistral_ocr"
            and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
            and file_ext
            in ["pdf"]  # Mistral OCR currently only supports PDF and images
        ):
            loader = MistralLoader(
                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
            )
        elif (
            self.engine == "external"
            and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
            and file_ext
            in ["pdf"]  # Mistral OCR currently only supports PDF and images
        ):
            loader = MistralLoader(
                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
            )
        else:
            if file_ext == "pdf":
                loader = PyPDFLoader(
                    file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
                )
            elif file_ext == "csv":
                loader = CSVLoader(file_path)
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_ext == "rst":
                loader = UnstructuredRSTLoader(file_path, mode="elements")
            elif file_ext == "xml":

@@ -182,9 +371,7 @@ class Loader:
                loader = UnstructuredPowerPointLoader(file_path)
            elif file_ext == "msg":
                loader = OutlookMessageLoader(file_path)
            elif file_ext in known_source_ext or (
                file_content_type and file_content_type.find("text/") >= 0
            ):
            elif self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TextLoader(file_path, autodetect_encoding=True)
|
|||
633
backend/open_webui/retrieval/loaders/mistral.py
Normal file
633
backend/open_webui/retrieval/loaders/mistral.py
Normal file
|
|
@ -0,0 +1,633 @@
|
|||
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager

from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MistralLoader:
    """
    Enhanced Mistral OCR loader with both sync and async support.
    Loads documents by processing them through the Mistral OCR API.
    """

    BASE_API_URL = "https://api.mistral.ai/v1"

    def __init__(
        self,
        api_key: str,
        file_path: str,
        timeout: int = 300,  # 5 minutes default
        max_retries: int = 3,
        enable_debug_logging: bool = False,
    ):
        """
        Initializes the loader with enhanced features.

        Args:
            api_key: Your Mistral API key.
            file_path: The local path to the PDF file to process.
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retry attempts.
            enable_debug_logging: Enable detailed debug logs.
        """
        if not api_key:
            raise ValueError("API key cannot be empty.")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found at {file_path}")

        self.api_key = api_key
        self.file_path = file_path
        self.timeout = timeout
        self.max_retries = max_retries
        self.debug = enable_debug_logging

        # Pre-compute file info for performance
        self.file_name = os.path.basename(file_path)
        self.file_size = os.path.getsize(file_path)

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "OpenWebUI-MistralLoader/2.0",
        }

    def _debug_log(self, message: str, *args) -> None:
        """Conditional debug logging for performance."""
        if self.debug:
            log.debug(message, *args)

    def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
        """Checks response status and returns JSON content."""
        try:
            response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx)
            # Handle potential empty responses for certain successful requests (e.g., DELETE)
            if response.status_code == 204 or not response.content:
                return {}  # Return empty dict if no content
            return response.json()
        except requests.exceptions.HTTPError as http_err:
            log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
            raise
        except requests.exceptions.RequestException as req_err:
            log.error(f"Request exception occurred: {req_err}")
            raise
        except ValueError as json_err:  # Includes JSONDecodeError
            log.error(f"JSON decode error: {json_err} - Response: {response.text}")
            raise  # Re-raise after logging

    async def _handle_response_async(
        self, response: aiohttp.ClientResponse
    ) -> Dict[str, Any]:
        """Async version of response handling with better error info."""
        try:
            response.raise_for_status()

            # Check content type
            content_type = response.headers.get("content-type", "")
            if "application/json" not in content_type:
                if response.status == 204:
                    return {}
                text = await response.text()
                raise ValueError(
                    f"Unexpected content type: {content_type}, body: {text[:200]}..."
                )

            return await response.json()

        except aiohttp.ClientResponseError as e:
            error_text = await response.text() if response else "No response"
            log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
            raise
        except aiohttp.ClientError as e:
            log.error(f"Client error: {e}")
            raise
        except Exception as e:
            log.error(f"Unexpected error processing response: {e}")
            raise

    def _retry_request_sync(self, request_func, *args, **kwargs):
        """Synchronous retry logic with exponential backoff."""
        for attempt in range(self.max_retries):
            try:
                return request_func(*args, **kwargs)
            except (requests.exceptions.RequestException, Exception) as e:
                if attempt == self.max_retries - 1:
                    raise

                wait_time = (2**attempt) + 0.5
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)

    async def _retry_request_async(self, request_func, *args, **kwargs):
        """Async retry logic with exponential backoff."""
        for attempt in range(self.max_retries):
            try:
                return await request_func(*args, **kwargs)
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt == self.max_retries - 1:
                    raise

                wait_time = (2**attempt) + 0.5
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)

    def _upload_file(self) -> str:
        """Uploads the file to Mistral for OCR processing (sync version)."""
        log.info("Uploading file to Mistral API")
        url = f"{self.BASE_API_URL}/files"
        file_name = os.path.basename(self.file_path)

        def upload_request():
            with open(self.file_path, "rb") as f:
                files = {"file": (file_name, f, "application/pdf")}
                data = {"purpose": "ocr"}

                response = requests.post(
                    url,
                    headers=self.headers,
                    files=files,
                    data=data,
                    timeout=self.timeout,
                )

                return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(upload_request)
            file_id = response_data.get("id")
            if not file_id:
                raise ValueError("File ID not found in upload response.")
            log.info(f"File uploaded successfully. File ID: {file_id}")
            return file_id
        except Exception as e:
            log.error(f"Failed to upload file: {e}")
            raise

    async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
        """Async file upload with streaming for better memory efficiency."""
        url = f"{self.BASE_API_URL}/files"

        async def upload_request():
            # Create multipart writer for streaming upload
            writer = aiohttp.MultipartWriter("form-data")

            # Add purpose field
            purpose_part = writer.append("ocr")
            purpose_part.set_content_disposition("form-data", name="purpose")

            # Add file part with streaming
            file_part = writer.append_payload(
                aiohttp.streams.FilePayload(
                    self.file_path,
                    filename=self.file_name,
                    content_type="application/pdf",
                )
            )
            file_part.set_content_disposition(
                "form-data", name="file", filename=self.file_name
            )

            self._debug_log(
                f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
            )

            async with session.post(
                url,
                data=writer,
                headers=self.headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                return await self._handle_response_async(response)

        response_data = await self._retry_request_async(upload_request)

        file_id = response_data.get("id")
        if not file_id:
            raise ValueError("File ID not found in upload response.")

        log.info(f"File uploaded successfully. File ID: {file_id}")
        return file_id

    def _get_signed_url(self, file_id: str) -> str:
        """Retrieves a temporary signed URL for the uploaded file (sync version)."""
        log.info(f"Getting signed URL for file ID: {file_id}")
        url = f"{self.BASE_API_URL}/files/{file_id}/url"
        params = {"expiry": 1}
        signed_url_headers = {**self.headers, "Accept": "application/json"}

        def url_request():
            response = requests.get(
                url, headers=signed_url_headers, params=params, timeout=self.timeout
            )
            return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(url_request)
            signed_url = response_data.get("url")
            if not signed_url:
                raise ValueError("Signed URL not found in response.")
            log.info("Signed URL received.")
            return signed_url
        except Exception as e:
            log.error(f"Failed to get signed URL: {e}")
            raise

    async def _get_signed_url_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> str:
        """Async signed URL retrieval."""
        url = f"{self.BASE_API_URL}/files/{file_id}/url"
        params = {"expiry": 1}

        headers = {**self.headers, "Accept": "application/json"}

        async def url_request():
            self._debug_log(f"Getting signed URL for file ID: {file_id}")
            async with session.get(
                url,
                headers=headers,
                params=params,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                return await self._handle_response_async(response)

        response_data = await self._retry_request_async(url_request)

        signed_url = response_data.get("url")
        if not signed_url:
            raise ValueError("Signed URL not found in response.")

        self._debug_log("Signed URL received successfully")
        return signed_url

    def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
        """Sends the signed URL to the OCR endpoint for processing (sync version)."""
        log.info("Processing OCR via Mistral API")
        url = f"{self.BASE_API_URL}/ocr"
        ocr_headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        def ocr_request():
            response = requests.post(
                url, headers=ocr_headers, json=payload, timeout=self.timeout
            )
            return self._handle_response(response)

        try:
            ocr_response = self._retry_request_sync(ocr_request)
            log.info("OCR processing done.")
            self._debug_log("OCR response: %s", ocr_response)
            return ocr_response
        except Exception as e:
            log.error(f"Failed during OCR processing: {e}")
            raise

    async def _process_ocr_async(
        self, session: aiohttp.ClientSession, signed_url: str
    ) -> Dict[str, Any]:
        """Async OCR processing with timing metrics."""
        url = f"{self.BASE_API_URL}/ocr"

        headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        async def ocr_request():
            log.info("Starting OCR processing via Mistral API")
            start_time = time.time()

            async with session.post(
                url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                ocr_response = await self._handle_response_async(response)

                processing_time = time.time() - start_time
                log.info(f"OCR processing completed in {processing_time:.2f}s")

                return ocr_response

        return await self._retry_request_async(ocr_request)

    def _delete_file(self, file_id: str) -> None:
        """Deletes the file from Mistral storage (sync version)."""
        log.info(f"Deleting uploaded file ID: {file_id}")
        url = f"{self.BASE_API_URL}/files/{file_id}"

        try:
            response = requests.delete(url, headers=self.headers, timeout=30)
            delete_response = self._handle_response(response)
            log.info(f"File deleted successfully: {delete_response}")
        except Exception as e:
            # Log error but don't necessarily halt execution if deletion fails
            log.error(f"Failed to delete file ID {file_id}: {e}")

    async def _delete_file_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> None:
        """Async file deletion with error tolerance."""
        try:

            async def delete_request():
                self._debug_log(f"Deleting file ID: {file_id}")
                async with session.delete(
                    url=f"{self.BASE_API_URL}/files/{file_id}",
                    headers=self.headers,
                    timeout=aiohttp.ClientTimeout(
                        total=30
                    ),  # Shorter timeout for cleanup
                ) as response:
                    return await self._handle_response_async(response)

            await self._retry_request_async(delete_request)
            self._debug_log(f"File {file_id} deleted successfully")

        except Exception as e:
            # Don't fail the entire process if cleanup fails
            log.warning(f"Failed to delete file ID {file_id}: {e}")

    @asynccontextmanager
    async def _get_session(self):
        """Context manager for HTTP session with optimized settings."""
        connector = aiohttp.TCPConnector(
            limit=10,  # Total connection limit
            limit_per_host=5,  # Per-host connection limit
            ttl_dns_cache=300,  # DNS cache TTL
            use_dns_cache=True,
            keepalive_timeout=30,
            enable_cleanup_closed=True,
        )

        async with aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(total=self.timeout),
            headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
        ) as session:
            yield session

    def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
        """Process OCR results into Document objects with enhanced metadata."""
        pages_data = ocr_response.get("pages")
        if not pages_data:
            log.warning("No pages found in OCR response.")
            return [
                Document(
                    page_content="No text content found", metadata={"error": "no_pages"}
                )
            ]

        documents = []
        total_pages = len(pages_data)
        skipped_pages = 0

        for page_data in pages_data:
            page_content = page_data.get("markdown")
            page_index = page_data.get("index")  # API uses 0-based index

            if page_content is not None and page_index is not None:
                # Clean up content efficiently
                cleaned_content = (
                    page_content.strip()
                    if isinstance(page_content, str)
                    else str(page_content)
                )

                if cleaned_content:  # Only add non-empty pages
                    documents.append(
                        Document(
                            page_content=cleaned_content,
                            metadata={
                                "page": page_index,  # 0-based index from API
                                "page_label": page_index + 1,  # 1-based label for convenience
                                "total_pages": total_pages,
                                "file_name": self.file_name,
                                "file_size": self.file_size,
                                "processing_engine": "mistral-ocr",
                            },
                        )
                    )
                else:
                    skipped_pages += 1
                    self._debug_log(f"Skipping empty page {page_index}")
            else:
                skipped_pages += 1
                self._debug_log(
                    f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
                )

        if skipped_pages > 0:
            log.info(
                f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
            )

        if not documents:
            # Case where pages existed but none had valid markdown/index
            log.warning(
                "OCR response contained pages, but none had valid content/index."
            )
            return [
                Document(
                    page_content="No valid text content found in document",
                    metadata={"error": "no_valid_pages", "total_pages": total_pages},
                )
            ]

        return documents

    def load(self) -> List[Document]:
        """
        Executes the full OCR workflow: upload, get URL, process OCR, delete file.
        Synchronous version for backward compatibility.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            # 1. Upload file
            file_id = self._upload_file()

            # 2. Get Signed URL
            signed_url = self._get_signed_url(file_id)

            # 3. Process OCR
            ocr_response = self._process_ocr(signed_url)

            # 4. Process results
            documents = self._process_results(ocr_response)

            total_time = time.time() - start_time
            log.info(
                f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
            )

            return documents

        except Exception as e:
            total_time = time.time() - start_time
            log.error(
                f"An error occurred during the loading process after {total_time:.2f}s: {e}"
            )
            # Return an error document on failure
            return [
                Document(
                    page_content=f"Error during processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Delete file (attempt even if prior steps failed after upload)
            if file_id:
                try:
                    self._delete_file(file_id)
                except Exception as del_e:
                    # Log deletion error, but don't overwrite original error if one occurred
                    log.error(
                        f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
                    )

    async def load_async(self) -> List[Document]:
        """
        Asynchronous OCR workflow execution with optimized performance.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            async with self._get_session() as session:
                # 1. Upload file with streaming
                file_id = await self._upload_file_async(session)

                # 2. Get signed URL
                signed_url = await self._get_signed_url_async(session, file_id)

                # 3. Process OCR
                ocr_response = await self._process_ocr_async(session, signed_url)

                # 4. Process results
                documents = self._process_results(ocr_response)

                total_time = time.time() - start_time
                log.info(
                    f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
                )

                return documents

        except Exception as e:
            total_time = time.time() - start_time
            log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
            return [
                Document(
                    page_content=f"Error during OCR processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Cleanup - always attempt file deletion
            if file_id:
                try:
                    async with self._get_session() as session:
                        await self._delete_file_async(session, file_id)
                except Exception as cleanup_error:
                    log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")

    @staticmethod
    async def load_multiple_async(
        loaders: List["MistralLoader"],
    ) -> List[List[Document]]:
        """
        Process multiple files concurrently for maximum performance.

        Args:
            loaders: List of MistralLoader instances

        Returns:
            List of document lists, one for each loader
        """
        if not loaders:
            return []

        log.info(f"Starting concurrent processing of {len(loaders)} files")
        start_time = time.time()

        # Process all files concurrently
        tasks = [loader.load_async() for loader in loaders]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle any exceptions in results
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                log.error(f"File {i} failed: {result}")
                processed_results.append(
                    [
                        Document(
                            page_content=f"Error processing file: {result}",
                            metadata={
                                "error": "batch_processing_failed",
                                "file_index": i,
                            },
                        )
                    ]
                )
            else:
                processed_results.append(result)

        total_time = time.time() - start_time
        total_docs = sum(len(docs) for docs in processed_results)
        log.info(
            f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
        )

        return processed_results
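For orientation, a minimal usage sketch of this loader; the API key and file paths below are placeholders, and asyncio.run is assumed to be available (i.e., no event loop already running):

# Hypothetical usage sketch for MistralLoader; key and paths are placeholders.
import asyncio

from open_webui.retrieval.loaders.mistral import MistralLoader

loader = MistralLoader(api_key="sk-...", file_path="/tmp/report.pdf")
docs = loader.load()  # sync path: upload -> signed URL -> OCR -> cleanup
print(len(docs), docs[0].metadata.get("page_label"))

# Batch path: one loader per file; per-file failures come back as error documents.
loaders = [
    MistralLoader(api_key="sk-...", file_path=p) for p in ["/tmp/a.pdf", "/tmp/b.pdf"]
]
results = asyncio.run(MistralLoader.load_multiple_async(loaders))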
93 backend/open_webui/retrieval/loaders/tavily.py (new file)
@@ -0,0 +1,93 @@
import requests
import logging
from typing import Iterator, List, Literal, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class TavilyLoader(BaseLoader):
    """Extract web page content from URLs using Tavily Extract API.

    This is a LangChain document loader that uses Tavily's Extract API to
    retrieve content from web pages and return it as Document objects.

    Args:
        urls: URL or list of URLs to extract content from.
        api_key: The Tavily API key.
        extract_depth: Depth of extraction, either "basic" or "advanced".
        continue_on_failure: Whether to continue if extraction of a URL fails.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        api_key: str,
        extract_depth: Literal["basic", "advanced"] = "basic",
        continue_on_failure: bool = True,
    ) -> None:
        """Initialize Tavily Extract client.

        Args:
            urls: URL or list of URLs to extract content from.
            api_key: The Tavily API key.
            extract_depth: Depth of extraction, either "basic" or "advanced".
                Advanced extraction retrieves more data, including tables and
                embedded content, with a higher success rate but possibly
                higher latency. Basic costs 1 credit per 5 successful URL
                extractions; advanced costs 2 credits per 5 successful URL
                extractions.
            continue_on_failure: Whether to continue if extraction of a URL fails.
        """
        if not urls:
            raise ValueError("At least one URL must be provided.")

        self.api_key = api_key
        self.urls = urls if isinstance(urls, list) else [urls]
        self.extract_depth = extract_depth
        self.continue_on_failure = continue_on_failure
        self.api_url = "https://api.tavily.com/extract"

    def lazy_load(self) -> Iterator[Document]:
        """Extract and yield documents from the URLs using Tavily Extract API."""
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            batch_urls = self.urls[i : i + batch_size]
            try:
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                }
                # Use string for single URL, array for multiple URLs
                urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
                payload = {"urls": urls_param, "extract_depth": self.extract_depth}
                # Make the API call
                response = requests.post(self.api_url, headers=headers, json=payload)
                response.raise_for_status()
                response_data = response.json()
                # Process successful results
                for result in response_data.get("results", []):
                    url = result.get("url", "")
                    content = result.get("raw_content", "")
                    if not content:
                        log.warning(f"No content extracted from {url}")
                        continue
                    # Add the source URL as metadata
                    metadata = {"source": url}
                    yield Document(
                        page_content=content,
                        metadata=metadata,
                    )
                for failed in response_data.get("failed_results", []):
                    url = failed.get("url", "")
                    error = failed.get("error", "Unknown error")
                    log.error(f"Failed to extract content from {url}: {error}")
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {batch_urls}: {e}")
                else:
                    raise e
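A minimal usage sketch of this loader; the API key and URLs are placeholders:

# Hypothetical usage sketch for TavilyLoader; key and URLs are placeholders.
from open_webui.retrieval.loaders.tavily import TavilyLoader

loader = TavilyLoader(
    urls=["https://example.com/a", "https://example.com/b"],
    api_key="tvly-...",
    extract_depth="basic",     # "advanced" also pulls tables/embedded content
    continue_on_failure=True,  # log per-URL failures instead of raising
)
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))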
@@ -62,12 +62,17 @@ class YoutubeLoader:
         _video_id = _parse_video_id(video_id)
         self.video_id = _video_id if _video_id is not None else video_id
         self._metadata = {"source": video_id}
-        self.language = language
         self.proxy_url = proxy_url
+
+        # Ensure language is a list
         if isinstance(language, str):
             self.language = [language]
         else:
-            self.language = language
+            self.language = list(language)
+
+        # Add English as fallback if not already in the list
+        if "en" not in self.language:
+            self.language.append("en")
 
     def load(self) -> List[Document]:
         """Load YouTube transcripts into `Document` objects."""
@@ -101,17 +106,31 @@ class YoutubeLoader:
             log.exception("Loading YouTube transcript failed")
             return []
 
-        try:
-            transcript = transcript_list.find_transcript(self.language)
-        except NoTranscriptFound:
-            transcript = transcript_list.find_transcript(["en"])
-
-        transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
-
-        transcript = " ".join(
-            map(
-                lambda transcript_piece: transcript_piece["text"].strip(" "),
-                transcript_pieces,
-            )
-        )
-        return [Document(page_content=transcript, metadata=self._metadata)]
+        # Try each language in order of priority
+        for lang in self.language:
+            try:
+                transcript = transcript_list.find_transcript([lang])
+                log.debug(f"Found transcript for language '{lang}'")
+                transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+                transcript_text = " ".join(
+                    map(
+                        lambda transcript_piece: transcript_piece.text.strip(" "),
+                        transcript_pieces,
+                    )
+                )
+                return [Document(page_content=transcript_text, metadata=self._metadata)]
+            except NoTranscriptFound:
+                log.debug(f"No transcript found for language '{lang}'")
+                continue
+            except Exception as e:
+                log.info(f"Error finding transcript for language '{lang}'")
+                raise e
+
+        # If we get here, all languages failed
+        languages_tried = ", ".join(self.language)
+        log.warning(
+            f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
+        )
+        raise NoTranscriptFound(
+            f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
+        )
8 backend/open_webui/retrieval/models/base_reranker.py (new file)
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple


class BaseReranker(ABC):
    @abstractmethod
    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        pass
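Any reranker only has to satisfy this interface: take (query, document) pairs and return one relevance score per pair, or None on failure. A minimal sketch of a local implementation, assuming sentence-transformers is installed; the model name is a placeholder, not a project default:

# Minimal sketch, assuming sentence-transformers is available; the model
# name below is a placeholder.
from typing import Optional, List, Tuple

from sentence_transformers import CrossEncoder

from open_webui.retrieval.models.base_reranker import BaseReranker


class CrossEncoderReranker(BaseReranker):
    def __init__(self, name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model = CrossEncoder(name)

    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        # CrossEncoder.predict takes a list of (query, passage) pairs and
        # returns one score per pair as a numpy array.
        return self.model.predict(sentences).tolist()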
@@ -1,13 +1,21 @@
 import os
+import logging
 import torch
 import numpy as np
 from colbert.infra import ColBERTConfig
 from colbert.modeling.checkpoint import Checkpoint
 
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.models.base_reranker import BaseReranker
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
 
-class ColBERT:
+class ColBERT(BaseReranker):
     def __init__(self, name, **kwargs) -> None:
-        print("ColBERT: Loading model", name)
+        log.info("ColBERT: Loading model", name)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
         DOCKER = kwargs.get("env") == "docker"
60 backend/open_webui/retrieval/models/external.py (new file)
@@ -0,0 +1,60 @@
import logging
import requests
from typing import Optional, List, Tuple

from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ExternalReranker(BaseReranker):
    def __init__(
        self,
        api_key: str,
        url: str = "http://localhost:8080/v1/rerank",
        model: str = "reranker",
    ):
        self.api_key = api_key
        self.url = url
        self.model = model

    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        query = sentences[0][0]
        docs = [i[1] for i in sentences]

        payload = {
            "model": self.model,
            "query": query,
            "documents": docs,
            "top_n": len(docs),
        }

        try:
            log.info(f"ExternalReranker:predict:model {self.model}")
            log.info(f"ExternalReranker:predict:query {query}")

            r = requests.post(
                f"{self.url}",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                },
                json=payload,
            )

            r.raise_for_status()
            data = r.json()

            if "results" in data:
                sorted_results = sorted(data["results"], key=lambda x: x["index"])
                return [result["relevance_score"] for result in sorted_results]
            else:
                log.error("No results found in external reranking response")
                return None

        except Exception as e:
            log.exception(f"Error in external reranking: {e}")
            return None
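A hedged round-trip sketch against a rerank endpoint following the /v1/rerank convention used above; the URL, key, and scores are illustrative placeholders:

# Hypothetical usage; endpoint URL, key, and the example scores are placeholders.
reranker = ExternalReranker(api_key="sk-...", url="http://localhost:8080/v1/rerank")

pairs = [
    ("what is open webui", "Open WebUI is a self-hosted AI interface."),
    ("what is open webui", "Bananas are rich in potassium."),
]
scores = reranker.predict(pairs)
# One relevance score per document, restored to input order via the
# "index" field of each result, e.g. [0.92, 0.03]; None on any error.
print(scores)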
@@ -1,27 +1,36 @@
 import logging
 import os
 import uuid
 from typing import Optional, Union
 
 import asyncio
 import requests
+import hashlib
+from concurrent.futures import ThreadPoolExecutor
+import time
 
 from huggingface_hub import snapshot_download
 from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 
 
 from open_webui.config import VECTOR_DB
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.misc import get_last_user_message
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 
+from open_webui.retrieval.vector.main import GetResult
+
 
 from open_webui.env import (
     SRC_LOG_LEVELS,
     OFFLINE_MODE,
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )
+from open_webui.config import (
+    RAG_EMBEDDING_QUERY_PREFIX,
+    RAG_EMBEDDING_CONTENT_PREFIX,
+    RAG_EMBEDDING_PREFIX_FIELD_NAME,
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -46,7 +55,7 @@ class VectorSearchRetriever(BaseRetriever):
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query)],
+            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
             limit=self.top_k,
         )
@@ -69,6 +78,7 @@ def query_doc(
     collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
 ):
     try:
+        log.debug(f"query_doc:doc {collection_name}")
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
@@ -80,24 +90,40 @@ def query_doc(
 
         return result
     except Exception as e:
-        print(e)
+        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
+        raise e
+
+
+def get_doc(collection_name: str, user: UserModel = None):
+    try:
+        log.debug(f"get_doc:doc {collection_name}")
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+
+        if result:
+            log.info(f"query_doc:result {result.ids} {result.metadatas}")
+
+        return result
+    except Exception as e:
+        log.exception(f"Error getting doc {collection_name}: {e}")
         raise e
 
 
 def query_doc_with_hybrid_search(
     collection_name: str,
+    collection_result: GetResult,
     query: str,
     embedding_function,
     k: int,
     reranking_function,
+    k_reranker: int,
     r: float,
+    hybrid_bm25_weight: float,
 ) -> dict:
     try:
-        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
-
+        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
         bm25_retriever = BM25Retriever.from_texts(
-            texts=result.documents[0],
-            metadatas=result.metadatas[0],
+            texts=collection_result.documents[0],
+            metadatas=collection_result.metadatas[0],
         )
         bm25_retriever.k = k
@@ -107,12 +133,23 @@ def query_doc_with_hybrid_search(
             top_k=k,
         )
 
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
-        )
+        if hybrid_bm25_weight <= 0:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[vector_search_retriever], weights=[1.0]
+            )
+        elif hybrid_bm25_weight >= 1:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[bm25_retriever], weights=[1.0]
+            )
+        else:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[bm25_retriever, vector_search_retriever],
+                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
+            )
 
         compressor = RerankCompressor(
             embedding_function=embedding_function,
-            top_n=k,
+            top_n=k_reranker,
             reranking_function=reranking_function,
             r_score=r,
         )
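The weighting rule above replaces the fixed 0.5/0.5 ensemble: a BM25 weight w selects pure vector search at w <= 0, pure BM25 at w >= 1, and a blend otherwise. An illustrative restatement (not project code):

# Illustration of the branch selection above; names are for demonstration only.
def pick_weights(w: float) -> dict:
    if w <= 0:
        return {"retrievers": ["vector"], "weights": [1.0]}
    elif w >= 1:
        return {"retrievers": ["bm25"], "weights": [1.0]}
    return {"retrievers": ["bm25", "vector"], "weights": [w, 1.0 - w]}

assert pick_weights(0.3)["weights"] == [0.3, 0.7]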
@@ -122,10 +159,23 @@ def query_doc_with_hybrid_search(
         )
 
         result = compression_retriever.invoke(query)
 
+        distances = [d.metadata.get("score") for d in result]
+        documents = [d.page_content for d in result]
+        metadatas = [d.metadata for d in result]
+
+        # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
+        if k < k_reranker:
+            sorted_items = sorted(
+                zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True
+            )
+            sorted_items = sorted_items[:k]
+            distances, documents, metadatas = map(list, zip(*sorted_items))
+
         result = {
-            "distances": [[d.metadata.get("score") for d in result]],
-            "documents": [[d.page_content for d in result]],
-            "metadatas": [[d.metadata for d in result]],
+            "distances": [distances],
+            "documents": [documents],
+            "metadatas": [metadatas],
         }
 
         log.info(
@@ -134,52 +184,88 @@ def query_doc_with_hybrid_search(
         )
         return result
     except Exception as e:
         log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
         raise e
 
 
-def merge_and_sort_query_results(
-    query_results: list[dict], k: int, reverse: bool = False
-) -> list[dict]:
+def merge_get_results(get_results: list[dict]) -> dict:
     # Initialize lists to store combined data
-    combined_distances = []
     combined_documents = []
     combined_metadatas = []
+    combined_ids = []
 
-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
+    for data in get_results:
         combined_documents.extend(data["documents"][0])
         combined_metadatas.extend(data["metadatas"][0])
-
-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
-
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
-    else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
-
-        # Slicing the lists to include only k elements
-        sorted_distances = list(sorted_distances)[:k]
-        sorted_documents = list(sorted_documents)[:k]
-        sorted_metadatas = list(sorted_metadatas)[:k]
+        combined_ids.extend(data["ids"][0])
 
     # Create the output dictionary
     result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
+        "documents": [combined_documents],
+        "metadatas": [combined_metadatas],
+        "ids": [combined_ids],
     }
 
     return result
 
 
+def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
+    # Initialize lists to store combined data
+    combined = dict()  # To store documents with unique document hashes
+
+    for data in query_results:
+        distances = data["distances"][0]
+        documents = data["documents"][0]
+        metadatas = data["metadatas"][0]
+
+        for distance, document, metadata in zip(distances, documents, metadatas):
+            if isinstance(document, str):
+                doc_hash = hashlib.sha256(
+                    document.encode()
+                ).hexdigest()  # Compute a hash for uniqueness
+
+                if doc_hash not in combined.keys():
+                    combined[doc_hash] = (distance, document, metadata)
+                    continue  # if doc is new, no further comparison is needed
+
+                # if doc is already in, but new distance is better, update
+                if distance > combined[doc_hash][0]:
+                    combined[doc_hash] = (distance, document, metadata)
+
+    combined = list(combined.values())
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0], reverse=True)
+
+    # Slice to keep only the top k elements
+    sorted_distances, sorted_documents, sorted_metadatas = (
+        zip(*combined[:k]) if combined else ([], [], [])
+    )
+
+    # Create and return the output dictionary
+    return {
+        "distances": [list(sorted_distances)],
+        "documents": [list(sorted_documents)],
+        "metadatas": [list(sorted_metadatas)],
+    }
+
+
+def get_all_items_from_collections(collection_names: list[str]) -> dict:
+    results = []
+
+    for collection_name in collection_names:
+        if collection_name:
+            try:
+                result = get_doc(collection_name=collection_name)
+                if result is not None:
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
+        else:
+            pass
+
+    return merge_get_results(results)
+
+
 def query_collection(
     collection_names: list[str],
     queries: list[str],
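A worked example of the dedup-and-sort behaviour of the new merge_and_sort_query_results (inputs are illustrative): the same document text appearing in two result sets is kept once, with the better (higher) distance, then everything is sorted descending and cut to the top k.

# Illustrative input/output for merge_and_sort_query_results.
a = {"distances": [[0.9, 0.2]], "documents": [["alpha", "beta"]], "metadatas": [[{}, {}]]}
b = {"distances": [[0.5]], "documents": [["alpha"]], "metadatas": [[{}]]}

merged = merge_and_sort_query_results([a, b], k=2)
# "alpha" is deduplicated via its SHA-256 hash and keeps distance 0.9:
# {"distances": [[0.9, 0.2]], "documents": [["alpha", "beta"]], "metadatas": [[{}, {}]]}
print(merged)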
@@ -187,29 +273,49 @@ def query_collection(
     k: int,
 ) -> dict:
     results = []
-    for query in queries:
-        query_embedding = embedding_function(query)
-        for collection_name in collection_names:
-            if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+    error = False
 
-    if VECTOR_DB == "chroma":
-        # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
-        # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
-        return merge_and_sort_query_results(results, k=k, reverse=False)
-    else:
-        return merge_and_sort_query_results(results, k=k, reverse=True)
+    def process_query_collection(collection_name, query_embedding):
+        try:
+            if collection_name:
+                result = query_doc(
+                    collection_name=collection_name,
+                    k=k,
+                    query_embedding=query_embedding,
+                )
+                if result is not None:
+                    return result.model_dump(), None
+            return None, None
+        except Exception as e:
+            log.exception(f"Error when querying the collection: {e}")
+            return None, e
+
+    # Generate all query embeddings (in one call)
+    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    log.debug(
+        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
+    )
+
+    with ThreadPoolExecutor() as executor:
+        future_results = []
+        for query_embedding in query_embeddings:
+            for collection_name in collection_names:
+                result = executor.submit(
+                    process_query_collection, collection_name, query_embedding
+                )
+                future_results.append(result)
+        task_results = [future.result() for future in future_results]
+
+    for result, err in task_results:
+        if err is not None:
+            error = True
+        elif result is not None:
+            results.append(result)
+
+    if error and not results:
+        log.warning("All collection queries failed. No results returned.")
+
+    return merge_and_sort_query_results(results, k=k)
 
 
 def query_collection_with_hybrid_search(
@@ -218,39 +324,74 @@ def query_collection_with_hybrid_search(
     embedding_function,
     k: int,
     reranking_function,
+    k_reranker: int,
     r: float,
+    hybrid_bm25_weight: float,
 ) -> dict:
     results = []
     error = False
+    # Fetch collection data once per collection sequentially
+    # Avoid fetching the same data multiple times later
+    collection_results = {}
     for collection_name in collection_names:
         try:
-            for query in queries:
-                result = query_doc_with_hybrid_search(
-                    collection_name=collection_name,
-                    query=query,
-                    embedding_function=embedding_function,
-                    k=k,
-                    reranking_function=reranking_function,
-                    r=r,
-                )
-                results.append(result)
+            log.debug(
+                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
+            )
+            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
+                collection_name=collection_name
+            )
         except Exception as e:
-            log.exception(
-                "Error when querying the collection with " f"hybrid_search: {e}"
-            )
-            error = True
+            log.exception(f"Failed to fetch collection {collection_name}: {e}")
+            collection_results[collection_name] = None
 
-    if error:
-        raise Exception(
-            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
-        )
+    log.info(
+        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
+    )
 
-    if VECTOR_DB == "chroma":
-        # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
-        # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
-        return merge_and_sort_query_results(results, k=k, reverse=False)
-    else:
-        return merge_and_sort_query_results(results, k=k, reverse=True)
+    def process_query(collection_name, query):
+        try:
+            result = query_doc_with_hybrid_search(
+                collection_name=collection_name,
+                collection_result=collection_results[collection_name],
+                query=query,
+                embedding_function=embedding_function,
+                k=k,
+                reranking_function=reranking_function,
+                k_reranker=k_reranker,
+                r=r,
+                hybrid_bm25_weight=hybrid_bm25_weight,
+            )
+            return result, None
+        except Exception as e:
+            log.exception(f"Error when querying the collection with hybrid_search: {e}")
+            return None, e
+
+    # Prepare tasks for all collections and queries
+    # Avoid running any tasks for collections that failed to fetch data (have assigned None)
+    tasks = [
+        (cn, q)
+        for cn in collection_names
+        if collection_results[cn] is not None
+        for q in queries
+    ]
+
+    with ThreadPoolExecutor() as executor:
+        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
+        task_results = [future.result() for future in future_results]
+
+    for result, err in task_results:
+        if err is not None:
+            error = True
+        elif result is not None:
+            results.append(result)
+
+    if error and not results:
+        raise Exception(
+            "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
+        )
+
+    return merge_and_sort_query_results(results, k=k)
 
 
 def get_embedding_function(
@@ -260,58 +401,132 @@ def get_embedding_function(
     url,
     key,
     embedding_batch_size,
+    azure_api_version=None,
 ):
     if embedding_engine == "":
-        return lambda query, user=None: embedding_function.encode(query).tolist()
-    elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query, user=None: generate_embeddings(
+        return lambda query, prefix=None, user=None: embedding_function.encode(
+            query, **({"prompt": prefix} if prefix else {})
+        ).tolist()
+    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
+        func = lambda query, prefix=None, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
+            prefix=prefix,
             url=url,
             key=key,
             user=user,
+            azure_api_version=azure_api_version,
         )
 
-        def generate_multiple(query, user, func):
+        def generate_multiple(query, prefix, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
                     embeddings.extend(
-                        func(query[i : i + embedding_batch_size], user=user)
+                        func(
+                            query[i : i + embedding_batch_size],
+                            prefix=prefix,
+                            user=user,
+                        )
                     )
                 return embeddings
             else:
-                return func(query, user)
+                return func(query, prefix, user)
 
-        return lambda query, user=None: generate_multiple(query, user, func)
+        return lambda query, prefix=None, user=None: generate_multiple(
+            query, prefix, user, func
+        )
+    else:
+        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
 def get_sources_from_files(
     request,
     files,
     queries,
     embedding_function,
     k,
     reranking_function,
+    k_reranker,
     r,
+    hybrid_bm25_weight,
     hybrid_search,
+    full_context=False,
 ):
-    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
+    log.debug(
+        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+    )
 
     extracted_collections = []
     relevant_contexts = []
 
     for file in files:
-        if file.get("context") == "full":
+        context = None
+        if file.get("docs"):
+            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+            context = {
+                "documents": [[doc.get("content") for doc in file.get("docs")]],
+                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
+            }
+        elif file.get("context") == "full":
+            # Manual Full Mode Toggle
             context = {
                 "documents": [[file.get("file").get("data", {}).get("content")]],
                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
             }
-        else:
-            context = None
+        elif (
+            file.get("type") != "web_search"
+            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+        ):
+            # BYPASS_EMBEDDING_AND_RETRIEVAL
+            if file.get("type") == "collection":
+                file_ids = file.get("data", {}).get("file_ids", [])
+
+                documents = []
+                metadatas = []
+                for file_id in file_ids:
+                    file_object = Files.get_file_by_id(file_id)
+
+                    if file_object:
+                        documents.append(file_object.data.get("content", ""))
+                        metadatas.append(
+                            {
+                                "file_id": file_id,
+                                "name": file_object.filename,
+                                "source": file_object.filename,
+                            }
+                        )
+
+                context = {
+                    "documents": [documents],
+                    "metadatas": [metadatas],
+                }
+
+            elif file.get("id"):
+                file_object = Files.get_file_by_id(file.get("id"))
+                if file_object:
+                    context = {
+                        "documents": [[file_object.data.get("content", "")]],
+                        "metadatas": [
+                            [
+                                {
+                                    "file_id": file.get("id"),
+                                    "name": file_object.filename,
+                                    "source": file_object.filename,
+                                }
+                            ]
+                        ],
+                    }
+            elif file.get("file").get("data"):
+                context = {
+                    "documents": [[file.get("file").get("data", {}).get("content")]],
+                    "metadatas": [
+                        [file.get("file").get("data", {}).get("metadata", {})]
+                    ],
+                }
         else:
             collection_names = []
             if file.get("type") == "collection":
                 if file.get("legacy"):
@@ -331,42 +546,52 @@ def get_sources_from_files(
                 log.debug(f"skipping {file} as it has already been extracted")
                 continue
 
-            try:
-                context = None
-                if file.get("type") == "text":
-                    context = file["content"]
-                else:
-                    if hybrid_search:
-                        try:
-                            context = query_collection_with_hybrid_search(
-                                collection_names=collection_names,
-                                queries=queries,
-                                embedding_function=embedding_function,
-                                k=k,
-                                reranking_function=reranking_function,
-                                r=r,
-                            )
-                        except Exception as e:
-                            log.debug(
-                                "Error when using hybrid search, using"
-                                " non hybrid search as fallback."
-                            )
-
-                    if (not hybrid_search) or (context is None):
-                        context = query_collection(
-                            collection_names=collection_names,
-                            queries=queries,
-                            embedding_function=embedding_function,
-                            k=k,
-                            reranking_function=reranking_function,
-                            r=r,
-                        )
-            except Exception as e:
-                log.exception(e)
+            if full_context:
+                try:
+                    context = get_all_items_from_collections(collection_names)
+                except Exception as e:
+                    log.exception(e)
+
+            else:
+                try:
+                    context = None
+                    if file.get("type") == "text":
+                        context = file["content"]
+                    else:
+                        if hybrid_search:
+                            try:
+                                context = query_collection_with_hybrid_search(
+                                    collection_names=collection_names,
+                                    queries=queries,
+                                    embedding_function=embedding_function,
+                                    k=k,
+                                    reranking_function=reranking_function,
+                                    k_reranker=k_reranker,
+                                    r=r,
+                                    hybrid_bm25_weight=hybrid_bm25_weight,
+                                )
+                            except Exception as e:
+                                log.debug(
+                                    "Error when using hybrid search, using"
+                                    " non hybrid search as fallback."
+                                )
+
+                        if (not hybrid_search) or (context is None):
+                            context = query_collection(
+                                collection_names=collection_names,
+                                queries=queries,
+                                embedding_function=embedding_function,
+                                k=k,
+                            )
+                except Exception as e:
+                    log.exception(e)
 
         extracted_collections.extend(collection_names)
 
         if context:
+            if "data" in file:
+                del file["data"]
+
             relevant_contexts.append({**context, "file": file})
 
     sources = []
@@ -435,9 +660,17 @@ def generate_openai_batch_embeddings(
     texts: list[str],
     url: str = "https://api.openai.com/v1",
     key: str = "",
+    prefix: str = None,
     user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
+        log.debug(
+            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
         r = requests.post(
             f"{url}/embeddings",
             headers={
@@ -454,7 +687,7 @@ def generate_openai_batch_embeddings(
                 else {}
             ),
         },
-        json={"input": texts, "model": model},
+        json=json_data,
     )
     r.raise_for_status()
     data = r.json()
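The effect of RAG_EMBEDDING_PREFIX_FIELD_NAME is easiest to see on the wire. Assuming it is configured as "input_type" (a hypothetical value, not a project default), the request body simply gains one extra field:

# Illustrative only: "input_type" is an assumed configuration value.
json_data = {"input": ["What is RAG?"], "model": "text-embedding-3-small"}
prefix = "query"
RAG_EMBEDDING_PREFIX_FIELD_NAME = "input_type"
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
    json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
# -> {"input": ["What is RAG?"], "model": "text-embedding-3-small", "input_type": "query"}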
@@ -463,14 +696,80 @@ def generate_openai_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating openai batch embeddings: {e}")
         return None
 
 
+def generate_azure_openai_batch_embeddings(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
+
+        for _ in range(5):
+            r = requests.post(
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": key,
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
+                json=json_data,
+            )
+            if r.status_code == 429:
+                retry = float(r.headers.get("Retry-After", "1"))
+                time.sleep(retry)
+                continue
+            r.raise_for_status()
+            data = r.json()
+            if "data" in data:
+                return [elem["embedding"] for elem in data["data"]]
+            else:
+                raise Exception("Something went wrong :/")
+        return None
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
+        log.debug(
+            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
         r = requests.post(
             f"{url}/api/embed",
             headers={
@@ -487,7 +786,7 @@ def generate_ollama_batch_embeddings(
                 else {}
             ),
         },
-        json={"input": texts, "model": model},
+        json=json_data,
    )
     r.raise_for_status()
     data = r.json()
@@ -497,37 +796,55 @@ def generate_ollama_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating ollama batch embeddings: {e}")
         return None
 
 
-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+def generate_embeddings(
+    engine: str,
+    model: str,
+    text: Union[str, list[str]],
+    prefix: Union[str, None] = None,
+    **kwargs,
+):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
     user = kwargs.get("user")
 
+    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
+        if isinstance(text, list):
+            text = [f"{prefix}{text_element}" for text_element in text]
+        else:
+            text = f"{prefix}{text}"
+
     if engine == "ollama":
-        if isinstance(text, list):
-            embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
-            )
-        else:
-            embeddings = generate_ollama_batch_embeddings(
-                **{
-                    "model": model,
-                    "texts": [text],
-                    "url": url,
-                    "key": key,
-                    "user": user,
-                }
-            )
+        embeddings = generate_ollama_batch_embeddings(
+            **{
+                "model": model,
+                "texts": text if isinstance(text, list) else [text],
+                "url": url,
+                "key": key,
+                "prefix": prefix,
+                "user": user,
+            }
+        )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
-        else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
-
+        embeddings = generate_openai_batch_embeddings(
+            model, text if isinstance(text, list) else [text], url, key, prefix, user
+        )
+        return embeddings[0] if isinstance(text, str) else embeddings
+    elif engine == "azure_openai":
+        azure_api_version = kwargs.get("azure_api_version", "")
+        embeddings = generate_azure_openai_batch_embeddings(
+            model,
+            text if isinstance(text, list) else [text],
+            url,
+            key,
+            azure_api_version,
+            prefix,
+            user,
+        )
         return embeddings[0] if isinstance(text, str) else embeddings
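A hedged sketch of call sites for the reworked generate_embeddings; model names and URLs are placeholders. A str input yields one vector, a list input yields a list of vectors:

# Hypothetical call sites; model names, key, and URLs are placeholders.
single = generate_embeddings(
    engine="openai",
    model="text-embedding-3-small",
    text="hello world",
    prefix=RAG_EMBEDDING_QUERY_PREFIX,
    url="https://api.openai.com/v1",
    key="sk-...",
)
batch = generate_embeddings(
    engine="ollama",
    model="nomic-embed-text",
    text=["doc one", "doc two"],
    url="http://localhost:11434",
    key="",
)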
@@ -563,13 +880,15 @@ class RerankCompressor(BaseDocumentCompressor):
         else:
             from sentence_transformers import util
 
-            query_embedding = self.embedding_function(query)
+            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
             document_embedding = self.embedding_function(
-                [doc.page_content for doc in documents]
+                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
 
-        docs_with_scores = list(zip(documents, scores.tolist()))
+        docs_with_scores = list(
+            zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
+        )
         if self.r_score:
             docs_with_scores = [
                 (d, s) for d, s in docs_with_scores if s >= self.r_score
@@ -1,22 +0,0 @@
-from open_webui.config import VECTOR_DB
-
-if VECTOR_DB == "milvus":
-    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
-
-    VECTOR_DB_CLIENT = MilvusClient()
-elif VECTOR_DB == "qdrant":
-    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
-
-    VECTOR_DB_CLIENT = QdrantClient()
-elif VECTOR_DB == "opensearch":
-    from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
-
-    VECTOR_DB_CLIENT = OpenSearchClient()
-elif VECTOR_DB == "pgvector":
-    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
-
-    VECTOR_DB_CLIENT = PgvectorClient()
-elif VECTOR_DB == "chroma":
-    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
-
-    VECTOR_DB_CLIENT = ChromaClient()
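This deleted module instantiated the selected vector DB client eagerly at import time. A minimal sketch of the same selection done lazily; the function name and caching choice here are assumptions for illustration, not the replacement this commit actually introduces:

# Hypothetical lazy factory equivalent to the deleted module.
from functools import lru_cache

@lru_cache(maxsize=1)
def get_vector_db_client(vector_db: str):
    if vector_db == "milvus":
        from open_webui.retrieval.vector.dbs.milvus import MilvusClient
        return MilvusClient()
    elif vector_db == "pgvector":
        from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
        return PgvectorClient()
    # ... other backends elided ...
    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
    return ChromaClient()  # default backend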
44  backend/open_webui/retrieval/vector/dbs/chroma.py  Normal file → Executable file

@@ -1,10 +1,16 @@
 import chromadb
+import logging
 from chromadb import Settings
 from chromadb.utils.batch_utils import create_batches
+
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -16,9 +22,13 @@ from open_webui.config import (
     CHROMA_CLIENT_AUTH_PROVIDER,
     CHROMA_CLIENT_AUTH_CREDENTIALS,
 )
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class ChromaClient:
+class ChromaClient(VectorDBBase):
     def __init__(self):
         settings_dict = {
             "allow_reset": True,
@@ -70,10 +80,16 @@ class ChromaClient:
             n_results=limit,
         )
 
+        # chromadb returns cosine distance, 2 (worst) -> 0 (best); re-ordering to 0 -> 1
+        # https://docs.trychroma.com/docs/collections/configure cosine equation
+        distances: list = result["distances"][0]
+        distances = [2 - dist for dist in distances]
+        distances = [[dist / 2 for dist in distances]]
+
         return SearchResult(
             **{
                 "ids": result["ids"],
-                "distances": result["distances"],
+                "distances": distances,
                 "documents": result["documents"],
                 "metadatas": result["metadatas"],
             }
@@ -102,8 +118,7 @@ class ChromaClient:
                 }
             )
             return None
-        except Exception as e:
-            print(e)
+        except:
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
@@ -162,12 +177,19 @@ class ChromaClient:
         filter: Optional[dict] = None,
     ):
         # Delete the items from the collection based on the ids.
-        collection = self.client.get_collection(name=collection_name)
-        if collection:
-            if ids:
-                collection.delete(ids=ids)
-            elif filter:
-                collection.delete(where=filter)
+        try:
+            collection = self.client.get_collection(name=collection_name)
+            if collection:
+                if ids:
+                    collection.delete(ids=ids)
+                elif filter:
+                    collection.delete(where=filter)
+        except Exception as e:
+            # If the collection doesn't exist, that's fine - nothing to delete
+            log.debug(
+                f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
+            )
+            pass
 
     def reset(self):
         # Resets the database. This will delete all collections and item entries.
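The two-step rewrite of distances above is just an affine rescale. A plain-Python check of the arithmetic, no Chroma required:

# Chroma cosine distance d is in [0, 2] with 0 = identical.
# score = (2 - d) / 2 maps it onto [0, 1] with 1 = identical.
def chroma_distance_to_score(d: float) -> float:
    return (2 - d) / 2

assert chroma_distance_to_score(0.0) == 1.0   # identical vectors
assert chroma_distance_to_score(2.0) == 0.0   # opposite vectors
assert chroma_distance_to_score(1.0) == 0.5   # orthogonal vectors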
300  backend/open_webui/retrieval/vector/dbs/elasticsearch.py  New file

@@ -0,0 +1,300 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)


class ElasticsearchClient(VectorDBBase):
    """
    Important:
    In order to reduce the number of indexes, and since the embedding vector
    length is fixed, we avoid creating an index per file and instead store the
    text as a field, while separating items into different indexes based on
    the embedding length.
    """

    def __init__(self):
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{str(dimension)}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> GetResult:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> GetResult:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # Adjust based on your vector dimensions
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        for i in range(0, len(items), batch_size):
            yield items[i : min(i + batch_size, len(items))]

    # Status: works
    def has_collection(self, collection_name) -> bool:
        query_body = {"query": {"bool": {"filter": []}}}
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )

        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)

            return result.body["count"] > 0
        except Exception as e:
            return None

    def delete_collection(self, collection_name: str):
        query = {"query": {"term": {"collection": collection_name}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                    },
                    "script": {
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming single query vector
                    },
                }
            },
        }

        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )

        return self._result_to_search_result(result)

    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None

        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }

        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10

        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )

            return self._result_to_get_result(result)

        except Exception as e:
            return None

    # Status: works
    def _has_index(self, dimension: int):
        return self.client.indices.exists(
            index=self._get_index_name(dimension=dimension)
        )

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
            "_source": ["text", "metadata"],
        }
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))

        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))

        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Upsert documents using the update API with doc_as_upsert=True.
    def upsert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "doc": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Delete specific documents from a collection by filtering on both collection and document IDs.
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        # logic based on chromaDB
        if ids:
            query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
        elif filter:
            for field, value in filter.items():
                query["query"]["bool"]["filter"].append(
                    {"term": {f"metadata.{field}": value}}
                )

        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    def reset(self):
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
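A hedged usage sketch of the new client, assuming an Elasticsearch instance is reachable through the ELASTICSEARCH_* settings and that embeddings happen to be 384-dimensional; the collection name and vectors below are made up:

# Hypothetical smoke test for the one-index-per-dimension layout.
client = ElasticsearchClient()
client.insert(
    "docs-demo",
    [{"id": "1", "vector": [0.1] * 384, "text": "hello", "metadata": {"source": "demo"}}],
)
hits = client.search("docs-demo", vectors=[[0.1] * 384], limit=1)
# The item lands in the index f"{ELASTICSEARCH_INDEX_PREFIX}_d384" and is
# scoped to its logical collection via the "collection" keyword field.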
@@ -1,30 +1,42 @@
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
 import json
+import logging
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_DB,
     MILVUS_TOKEN,
+    MILVUS_INDEX_TYPE,
+    MILVUS_METRIC_TYPE,
+    MILVUS_HNSW_M,
+    MILVUS_HNSW_EFCONSTRUCTION,
+    MILVUS_IVF_FLAT_NLIST,
 )
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class MilvusClient:
+class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
         if MILVUS_TOKEN is None:
-            self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
+            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
         else:
-            self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN)
+            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
 
     def _result_to_get_result(self, result) -> GetResult:
         ids = []
         documents = []
         metadatas = []
 
         for match in result:
             _ids = []
             _documents = []
@@ -33,11 +45,9 @@ class MilvusClient:
                 _ids.append(item.get("id"))
                 _documents.append(item.get("data", {}).get("text"))
                 _metadatas.append(item.get("metadata"))
 
             ids.append(_ids)
             documents.append(_documents)
             metadatas.append(_metadatas)
 
         return GetResult(
             **{
                 "ids": ids,
@@ -51,24 +61,23 @@ class MilvusClient:
         distances = []
         documents = []
         metadatas = []
 
         for match in result:
             _ids = []
             _distances = []
             _documents = []
             _metadatas = []
 
             for item in match:
                 _ids.append(item.get("id"))
-                _distances.append(item.get("distance"))
+                # normalize milvus score from [-1, 1] to [0, 1] range
+                # https://milvus.io/docs/de/metric.md
+                _dist = (item.get("distance") + 1.0) / 2.0
+                _distances.append(_dist)
                 _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                 _metadatas.append(item.get("entity", {}).get("metadata"))
 
             ids.append(_ids)
             distances.append(_distances)
             documents.append(_documents)
             metadatas.append(_metadatas)
 
         return SearchResult(
             **{
                 "ids": ids,
@@ -101,11 +110,39 @@ class MilvusClient:
         )
 
         index_params = self.client.prepare_index_params()
+
+        # Use configurations from config.py
+        index_type = MILVUS_INDEX_TYPE.upper()
+        metric_type = MILVUS_METRIC_TYPE.upper()
+
+        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
+
+        index_creation_params = {}
+        if index_type == "HNSW":
+            index_creation_params = {
+                "M": MILVUS_HNSW_M,
+                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
+            }
+            log.info(f"HNSW params: {index_creation_params}")
+        elif index_type == "IVF_FLAT":
+            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
+            log.info(f"IVF_FLAT params: {index_creation_params}")
+        elif index_type in ["FLAT", "AUTOINDEX"]:
+            log.info(f"Using {index_type} index with no specific build-time params.")
+        else:
+            log.warning(
+                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
+                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
+                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
+            )
+            # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
+            # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
+
         index_params.add_index(
             field_name="vector",
-            index_type="HNSW",
-            metric_type="COSINE",
-            params={"M": 16, "efConstruction": 100},
+            index_type=index_type,
+            metric_type=metric_type,
+            params=index_creation_params,
        )
 
         self.client.create_collection(
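The score rescale above, in isolation: Milvus COSINE similarity lands in [-1, 1], and (s + 1) / 2 moves it onto the [0, 1] scale the rest of the retrieval code expects. A plain-Python check:

# Standalone check of the normalization used in _result_to_search_result.
def milvus_score_to_unit(s: float) -> float:
    return (s + 1.0) / 2.0

assert milvus_score_to_unit(1.0) == 1.0    # identical direction
assert milvus_score_to_unit(-1.0) == 0.0   # opposite direction
assert milvus_score_to_unit(0.0) == 0.5    # orthogonal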
@@ -113,6 +150,9 @@ class MilvusClient:
             schema=schema,
             index_params=index_params,
         )
+        log.info(
+            f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
+        )
 
     def has_collection(self, collection_name: str) -> bool:
         # Check if the collection exists based on the collection name.
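How the new env-driven knobs map onto index build parameters, sketched without a live Milvus; the values shown are examples, not project defaults:

# Hypothetical settings -> the params dict _create_collection would build.
# MILVUS_INDEX_TYPE=HNSW, MILVUS_HNSW_M=16, MILVUS_HNSW_EFCONSTRUCTION=100
index_type = "HNSW"
index_creation_params = {"M": 16, "efConstruction": 100}
# MILVUS_INDEX_TYPE=IVF_FLAT with MILVUS_IVF_FLAT_NLIST=128 would instead give
# index_creation_params = {"nlist": 128}; FLAT and AUTOINDEX pass no build params.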
@@ -133,82 +173,113 @@ class MilvusClient:
     ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         collection_name = collection_name.replace("-", "_")
+        # For some index types like IVF_FLAT, search params like nprobe can be set.
+        # Example: search_params = {"nprobe": 10} if using IVF_FLAT
+        # For simplicity, not adding configurable search_params here, but could be extended.
         result = self.client.search(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=vectors,
             limit=limit,
             output_fields=["data", "metadata"],
+            # search_params=search_params  # Potentially add later if needed
         )
 
         return self._result_to_search_result(result)
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
         # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            log.warning(
+                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
+            )
+            return None
+
         filter_string = " && ".join(
             [
                 f'metadata["{key}"] == {json.dumps(value)}'
                 for key, value in filter.items()
             ]
         )
 
         max_limit = 16383  # The maximum number of records per request
         all_results = []
 
         if limit is None:
-            limit = float("inf")  # Use infinity as a placeholder for no limit
+            # Milvus default limit for query if not specified is 16384, but docs mention iteration.
+            # Let's set a practical high number if "all" is intended, or handle true pagination.
+            # For now, if limit is None, we'll fetch in batches up to a very large number.
+            # This part could be refined based on expected use cases for "get all".
+            # For this function signature, None implies "as many as possible" up to Milvus limits.
+            limit = (
+                16384 * 10
+            )  # A large number to signify fetching many; capped by actual data or max_limit per call.
+            log.info(
+                f"Limit not specified for query, fetching up to {limit} results in batches."
+            )
 
         # Initialize offset and remaining to handle pagination
         offset = 0
         remaining = limit
 
         try:
+            log.info(
+                f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
+            )
             # Loop until there are no more items to fetch or the desired limit is reached
             while remaining > 0:
-                print("remaining", remaining)
                 current_fetch = min(
-                    max_limit, remaining
-                )  # Determine how many items to fetch in this iteration
+                    max_limit, remaining if isinstance(remaining, int) else max_limit
+                )
+                log.debug(
+                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
+                )
 
                 results = self.client.query(
                     collection_name=f"{self.collection_prefix}_{collection_name}",
                     filter=filter_string,
-                    output_fields=["*"],
+                    output_fields=[
+                        "id",
+                        "data",
+                        "metadata",
+                    ],  # Explicitly list needed fields. Vector not usually needed in query.
                     limit=current_fetch,
                     offset=offset,
                 )
 
                 if not results:
+                    log.debug("No more results from query.")
                     break
 
                 all_results.extend(results)
                 results_count = len(results)
-                remaining -= (
-                    results_count  # Decrease remaining by the number of items fetched
-                )
+                log.debug(f"Fetched {results_count} results in this batch.")
+
+                if isinstance(remaining, int):
+                    remaining -= results_count
+
                 offset += results_count
 
-                # Break the loop if the results returned are less than the requested fetch count
+                # Break the loop if the results returned are less than the requested fetch count (means end of data)
                 if results_count < current_fetch:
+                    log.debug(
+                        "Fetched less than requested, assuming end of results for this query."
+                    )
                     break
 
-            print(all_results)
+            log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
         except Exception as e:
-            print(e)
+            log.exception(
+                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
+            )
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
-        # Get all the items in the collection.
+        # Get all the items in the collection. This can be very resource-intensive for large collections.
         collection_name = collection_name.replace("-", "_")
-        result = self.client.query(
-            collection_name=f"{self.collection_prefix}_{collection_name}",
-            filter='id != ""',
+        log.warning(
+            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
         )
-        return self._result_to_get_result([result])
+        # Using query with a trivial filter to get all items.
+        # This will use the paginated query logic.
+        return self.query(collection_name=collection_name, filter={}, limit=None)
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
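The batched query loop above, reduced to its skeleton; fetch_page below stands in for self.client.query and is not a real pymilvus call:

# Skeleton of the pagination pattern used in query().
def fetch_all(fetch_page, limit, max_per_call=16383):
    results, offset, remaining = [], 0, limit
    while remaining > 0:
        current_fetch = min(max_per_call, remaining)
        batch = fetch_page(limit=current_fetch, offset=offset)
        if not batch:
            break
        results.extend(batch)
        remaining -= len(batch)
        offset += len(batch)
        if len(batch) < current_fetch:
            break  # short page: end of data
    return results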
@@ -216,10 +287,23 @@ class MilvusClient:
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(
+                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
+            )
+            if not items:
+                log.error(
+                    f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
+                )
+                raise ValueError(
+                    "Cannot create Milvus collection without items to determine vector dimension."
+                )
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
 
+        log.info(
+            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
+        )
         return self.client.insert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -239,10 +323,23 @@ class MilvusClient:
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(
+                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
+            )
+            if not items:
+                log.error(
+                    f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
+                )
+                raise ValueError(
+                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
+                )
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
 
+        log.info(
+            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
+        )
         return self.client.upsert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -262,30 +359,55 @@ class MilvusClient:
         ids: Optional[list[str]] = None,
         filter: Optional[dict] = None,
     ):
-        # Delete the items from the collection based on the ids.
+        # Delete the items from the collection based on the ids or filter.
         collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            log.warning(
+                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
+            )
+            return None
+
         if ids:
+            log.info(
+                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
+            )
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 ids=ids,
             )
         elif filter:
             # Convert the filter dictionary to a string using JSON_CONTAINS.
             filter_string = " && ".join(
                 [
                     f'metadata["{key}"] == {json.dumps(value)}'
                     for key, value in filter.items()
                 ]
             )
 
+            log.info(
+                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
+            )
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 filter=filter_string,
             )
+        else:
+            log.warning(
+                f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
+            )
+            return None
 
     def reset(self):
-        # Resets the database. This will delete all collections and item entries.
+        # Resets the database. This will delete all collections and item entries that match the prefix.
+        log.warning(
+            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
+        )
         collection_names = self.client.list_collections()
-        for collection_name in collection_names:
-            if collection_name.startswith(self.collection_prefix):
-                self.client.drop_collection(collection_name=collection_name)
+        deleted_collections = []
+        for collection_name_full in collection_names:
+            if collection_name_full.startswith(self.collection_prefix):
+                try:
+                    self.client.drop_collection(collection_name=collection_name_full)
+                    deleted_collections.append(collection_name_full)
+                    log.info(f"Deleted collection: {collection_name_full}")
+                except Exception as e:
+                    log.error(f"Error deleting collection {collection_name_full}: {e}")
+        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
@@ -1,7 +1,13 @@
 from opensearchpy import OpenSearch
+from opensearchpy.helpers import bulk
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
@@ -11,7 +17,7 @@ from open_webui.config import (
 )
 
 
-class OpenSearchClient:
+class OpenSearchClient(VectorDBBase):
     def __init__(self):
         self.index_prefix = "open_webui"
         self.client = OpenSearch(
@@ -21,7 +27,13 @@ class OpenSearchClient:
             http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
         )
 
+    def _get_index_name(self, collection_name: str) -> str:
+        return f"{self.index_prefix}_{collection_name}"
+
     def _result_to_get_result(self, result) -> GetResult:
+        if not result["hits"]["hits"]:
+            return None
+
         ids = []
         documents = []
         metadatas = []
@@ -31,9 +43,12 @@ class OpenSearchClient:
             documents.append(hit["_source"].get("text"))
             metadatas.append(hit["_source"].get("metadata"))
 
-        return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
     def _result_to_search_result(self, result) -> SearchResult:
+        if not result["hits"]["hits"]:
+            return None
+
         ids = []
         distances = []
         documents = []
@@ -46,72 +61,88 @@ class OpenSearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            ids=[ids],
+            distances=[distances],
+            documents=[documents],
+            metadatas=[metadatas],
         )
 
-    def _create_index(self, index_name: str, dimension: int):
+    def _create_index(self, collection_name: str, dimension: int):
         body = {
+            "settings": {"index": {"knn": True}},
             "mappings": {
                 "properties": {
                     "id": {"type": "keyword"},
                     "vector": {
-                        "type": "dense_vector",
-                        "dims": dimension,  # Adjust based on your vector dimensions
-                        "index": true,
+                        "type": "knn_vector",
+                        "dimension": dimension,  # Adjust based on your vector dimensions
+                        "index": True,
                         "similarity": "faiss",
                         "method": {
                             "name": "hnsw",
-                            "space_type": "ip",  # Use inner product to approximate cosine similarity
+                            "space_type": "innerproduct",  # Use inner product to approximate cosine similarity
                             "engine": "faiss",
-                            "ef_construction": 128,
-                            "m": 16,
+                            "parameters": {
+                                "ef_construction": 128,
+                                "m": 16,
+                            },
                         },
                     },
                     "text": {"type": "text"},
                     "metadata": {"type": "object"},
                 }
-            }
+            },
         }
-        self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
+        self.client.indices.create(
+            index=self._get_index_name(collection_name), body=body
+        )
 
     def _create_batches(self, items: list[VectorItem], batch_size=100):
         for i in range(0, len(items), batch_size):
             yield items[i : i + batch_size]
 
-    def has_collection(self, index_name: str) -> bool:
+    def has_collection(self, collection_name: str) -> bool:
         # has_collection here means has index.
         # We are simply adapting to the norms of the other DBs.
-        return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
+        return self.client.indices.exists(index=self._get_index_name(collection_name))
 
-    def delete_colleciton(self, index_name: str):
+    def delete_collection(self, collection_name: str):
         # delete_collection here means delete index.
         # We are simply adapting to the norms of the other DBs.
-        self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
+        self.client.indices.delete(index=self._get_index_name(collection_name))
 
     def search(
-        self, index_name: str, vectors: list[list[float]], limit: int
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[SearchResult]:
-        query = {
-            "size": limit,
-            "_source": ["text", "metadata"],
-            "query": {
-                "script_score": {
-                    "query": {"match_all": {}},
-                    "script": {
-                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
-                        "params": {
-                            "vector": vectors[0]
-                        },  # Assuming single query vector
-                    },
-                }
-            },
-        }
-
-        result = self.client.search(
-            index=f"{self.index_prefix}_{index_name}", body=query
-        )
-
-        return self._result_to_search_result(result)
+        try:
+            if not self.has_collection(collection_name):
+                return None
+
+            query = {
+                "size": limit,
+                "_source": ["text", "metadata"],
+                "query": {
+                    "script_score": {
+                        "query": {"match_all": {}},
+                        "script": {
+                            "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
+                            "params": {
+                                "field": "vector",
+                                "query_value": vectors[0],
+                            },  # Assuming single query vector
+                        },
+                    }
+                },
+            }
+
+            result = self.client.search(
+                index=self._get_index_name(collection_name), body=query
+            )
+
+            return self._result_to_search_result(result)
+
+        except Exception as e:
+            return None
 
     def query(
         self, collection_name: str, filter: dict, limit: Optional[int] = None
@@ -125,13 +156,15 @@ class OpenSearchClient:
         }
 
         for field, value in filter.items():
-            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+            query_body["query"]["bool"]["filter"].append(
+                {"match": {"metadata." + str(field): value}}
+            )
 
         size = limit if limit else 10
 
         try:
             result = self.client.search(
-                index=f"{self.index_prefix}_{collection_name}",
+                index=self._get_index_name(collection_name),
                 body=query_body,
                 size=size,
             )
@@ -141,64 +174,88 @@ class OpenSearchClient:
         except Exception as e:
             return None
 
-    def get_or_create_index(self, index_name: str, dimension: int):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension)
+    def _create_index_if_not_exists(self, collection_name: str, dimension: int):
+        if not self.has_collection(collection_name):
+            self._create_index(collection_name, dimension)
 
-    def get(self, index_name: str) -> Optional[GetResult]:
+    def get(self, collection_name: str) -> Optional[GetResult]:
         query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
 
         result = self.client.search(
-            index=f"{self.index_prefix}_{index_name}", body=query
+            index=self._get_index_name(collection_name), body=query
        )
         return self._result_to_get_result(result)
 
-    def insert(self, index_name: str, items: list[VectorItem]):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension=len(items[0]["vector"]))
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        self._create_index_if_not_exists(
+            collection_name=collection_name, dimension=len(items[0]["vector"])
+        )
 
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "index": {
-                        "_id": item["id"],
-                        "_source": {
-                            "vector": item["vector"],
-                            "text": item["text"],
-                            "metadata": item["metadata"],
-                        },
-                    }
+                    "_op_type": "index",
+                    "_index": self._get_index_name(collection_name),
+                    "_id": item["id"],
+                    "_source": {
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
                 }
                 for item in batch
             ]
-            self.client.bulk(actions)
+            bulk(self.client, actions)
 
-    def upsert(self, index_name: str, items: list[VectorItem]):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension=len(items[0]["vector"]))
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        self._create_index_if_not_exists(
+            collection_name=collection_name, dimension=len(items[0]["vector"])
+        )
 
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "index": {
-                        "_id": item["id"],
-                        "_source": {
-                            "vector": item["vector"],
-                            "text": item["text"],
-                            "metadata": item["metadata"],
-                        },
-                    }
+                    "_op_type": "update",
+                    "_index": self._get_index_name(collection_name),
+                    "_id": item["id"],
+                    "doc": {
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                    "doc_as_upsert": True,
+                }
                 for item in batch
             ]
-            self.client.bulk(actions)
+            bulk(self.client, actions)
 
-    def delete(self, index_name: str, ids: list[str]):
-        actions = [
-            {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
-            for id in ids
-        ]
-        self.client.bulk(body=actions)
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[list[str]] = None,
+        filter: Optional[dict] = None,
+    ):
+        if ids:
+            actions = [
+                {
+                    "_op_type": "delete",
+                    "_index": self._get_index_name(collection_name),
+                    "_id": id,
+                }
+                for id in ids
+            ]
+            bulk(self.client, actions)
+        elif filter:
+            query_body = {
+                "query": {"bool": {"filter": []}},
+            }
+            for field, value in filter.items():
+                query_body["query"]["bool"]["filter"].append(
+                    {"match": {"metadata." + str(field): value}}
+                )
            self.client.delete_by_query(
                index=self._get_index_name(collection_name), body=query_body
            )
 
     def reset(self):
         indices = self.client.indices.get(index=f"{self.index_prefix}_*")
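The filter-to-query translation used by both the query() and delete() paths, shown standalone; this is plain dict building and needs no client:

# Builds the same bool-filter body the new query()/delete() paths assemble.
def build_metadata_filter(filter: dict) -> dict:
    body = {"query": {"bool": {"filter": []}}}
    for field, value in filter.items():
        body["query"]["bool"]["filter"].append(
            {"match": {"metadata." + str(field): value}}
        )
    return body

# build_metadata_filter({"file_id": "abc"}) ->
# {"query": {"bool": {"filter": [{"match": {"metadata.file_id": "abc"}}]}}}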
@@ -1,4 +1,5 @@
 from typing import Optional, List, Dict, Any
+import logging
 from sqlalchemy import (
     cast,
     column,
@@ -21,12 +22,22 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+from open_webui.env import SRC_LOG_LEVELS
 
 VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 Base = declarative_base()
 
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
 
 class DocumentChunk(Base):
     __tablename__ = "document_chunk"
@@ -38,7 +49,7 @@ class DocumentChunk(Base):
     vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
-class PgvectorClient:
+class PgvectorClient(VectorDBBase):
     def __init__(self) -> None:
         # if no pgvector uri, use the existing database connection
@@ -82,10 +93,10 @@ class PgvectorClient:
                 )
             )
             self.session.commit()
-            print("Initialization complete.")
+            log.info("Initialization complete.")
         except Exception as e:
             self.session.rollback()
-            print(f"Error during initialization: {e}")
+            log.exception(f"Error during initialization: {e}")
             raise
 
     def check_vector_length(self) -> None:
@@ -130,9 +141,8 @@ class PgvectorClient:
             # Pad the vector with zeros
             vector += [0.0] * (VECTOR_LENGTH - current_length)
         elif current_length > VECTOR_LENGTH:
-            raise Exception(
-                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
-            )
+            # Truncate the vector to VECTOR_LENGTH
+            vector = vector[:VECTOR_LENGTH]
         return vector
 
     def insert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -150,12 +160,12 @@ class PgvectorClient:
                 new_items.append(new_chunk)
             self.session.bulk_save_objects(new_items)
             self.session.commit()
-            print(
+            log.info(
                 f"Inserted {len(new_items)} items into collection '{collection_name}'."
             )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during insert: {e}")
+            log.exception(f"Error during insert: {e}")
             raise
 
     def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -184,10 +194,12 @@ class PgvectorClient:
                 )
                 self.session.add(new_chunk)
             self.session.commit()
-            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
+            log.info(
+                f"Upserted {len(items)} items into collection '{collection_name}'."
+            )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during upsert: {e}")
+            log.exception(f"Error during upsert: {e}")
             raise
 
     def search(
@@ -270,7 +282,9 @@ class PgvectorClient:
             for row in results:
                 qid = int(row.qid)
                 ids[qid].append(row.id)
-                distances[qid].append(row.distance)
+                # normalize and re-order pgvector cosine distance from [2, 0] to a [0, 1] score range
+                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
+                distances[qid].append((2.0 - row.distance) / 2.0)
                 documents[qid].append(row.text)
                 metadatas[qid].append(row.vmetadata)
 
@@ -278,7 +292,7 @@ class PgvectorClient:
                 ids=ids, distances=distances, documents=documents, metadatas=metadatas
             )
         except Exception as e:
-            print(f"Error during search: {e}")
+            log.exception(f"Error during search: {e}")
             return None
 
     def query(
@@ -310,7 +324,7 @@ class PgvectorClient:
                 metadatas=metadatas,
             )
         except Exception as e:
-            print(f"Error during query: {e}")
+            log.exception(f"Error during query: {e}")
             return None
 
     def get(
@@ -334,7 +348,7 @@ class PgvectorClient:
 
             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
-            print(f"Error during get: {e}")
+            log.exception(f"Error during get: {e}")
             return None
 
     def delete(
@@ -356,22 +370,22 @@ class PgvectorClient:
             )
             deleted = query.delete(synchronize_session=False)
             self.session.commit()
-            print(f"Deleted {deleted} items from collection '{collection_name}'.")
+            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
         except Exception as e:
             self.session.rollback()
-            print(f"Error during delete: {e}")
+            log.exception(f"Error during delete: {e}")
             raise
 
     def reset(self) -> None:
         try:
             deleted = self.session.query(DocumentChunk).delete()
             self.session.commit()
-            print(
+            log.info(
                 f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
             )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during reset: {e}")
+            log.exception(f"Error during reset: {e}")
             raise
 
     def close(self) -> None:
@@ -387,9 +401,9 @@ class PgvectorClient:
             )
             return exists
         except Exception as e:
-            print(f"Error checking collection existence: {e}")
+            log.exception(f"Error checking collection existence: {e}")
             return False
 
     def delete_collection(self, collection_name: str) -> None:
         self.delete(collection_name)
-        print(f"Collection '{collection_name}' deleted.")
+        log.info(f"Collection '{collection_name}' deleted.")
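Two small behaviors from this file, checked in plain Python: over-length vectors are now truncated instead of raising, and cosine distance is rescaled to a [0, 1] score. VECTOR_LENGTH below is an example value, not the project default:

VECTOR_LENGTH = 1536  # example value of PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

def adjust_vector_length(vector: list[float]) -> list[float]:
    # Pad short vectors with zeros; silently truncate long ones (new behavior).
    if len(vector) < VECTOR_LENGTH:
        return vector + [0.0] * (VECTOR_LENGTH - len(vector))
    return vector[:VECTOR_LENGTH]

def pgvector_distance_to_score(d: float) -> float:
    return (2.0 - d) / 2.0  # cosine distance [2, 0] -> score [0, 1]

assert len(adjust_vector_length([0.5] * 2000)) == VECTOR_LENGTH
assert pgvector_distance_to_score(0.0) == 1.0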
508  backend/open_webui/retrieval/vector/dbs/pinecone.py  New file

@@ -0,0 +1,508 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec

import asyncio  # for async upserts
import functools  # for partial binding in async tasks

import concurrent.futures  # for parallel batch upserts

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client for improved performance
        self.client = Pinecone(api_key=self.api_key)

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)

        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = getattr(match, "metadata", {}) or {}
            ids.append(match.id if hasattr(match, "id") else match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            matches = getattr(response, "matches", []) or []
            return len(matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise
        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise
        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )
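The batching-plus-thread-pool pattern used by insert() and upsert(), reduced to a standalone sketch; upsert_batch below is a stand-in for index.upsert, not a Pinecone API:

import concurrent.futures

def parallel_upsert(upsert_batch, points, batch_size=100, max_workers=5):
    # Submit each slice of points to the pool, then surface the first failure.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(upsert_batch, points[i : i + batch_size])
            for i in range(0, len(points), batch_size)
        ]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raises any exception from the worker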
    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]
        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result
        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]
        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result
        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            matches = getattr(query_response, "matches", []) or []
            if not matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(getattr(match, "score", 0.0))
                    for match in matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )

            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)

        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)

        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name
|
||||
# This is a limitation of Pinecone - be careful with ID uniqueness
|
||||
self.index.delete(ids=batch_ids)
|
||||
log.debug(
|
||||
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(
|
||||
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
elif filter:
|
||||
# Combine user filter with collection_name
|
||||
pinecone_filter = {"collection_name": collection_name_with_prefix}
|
||||
if filter:
|
||||
pinecone_filter.update(filter)
|
||||
# Delete by metadata filter
|
||||
self.index.delete(filter=pinecone_filter)
|
||||
log.info(
|
||||
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
else:
|
||||
log.warning("No ids or filter provided for delete operation")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting from collection '{collection_name}': {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the database by deleting all collections."""
|
||||
try:
|
||||
self.index.delete(delete_all=True)
|
||||
log.info("All vectors successfully deleted from the index.")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to reset Pinecone index: {e}")
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Shut down resources."""
|
||||
try:
|
||||
# The new Pinecone client doesn't need explicit closing
|
||||
pass
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to clean up Pinecone resources: {e}")
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit context manager, ensuring resources are cleaned up."""
|
||||
self.close()
|
||||
|
|
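A minimal usage sketch of the client above (illustrative only, not part of the commit; the collection name and query vector are placeholders):

    from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

    # Entering the context manager guarantees close() runs, which shuts down
    # the internal thread pool even if a call below raises.
    with PineconeClient() as client:
        result = client.search(
            "demo-collection",                   # placeholder collection name
            vectors=[[0.1] * client.dimension],  # one query vector of the index dimension
            limit=5,
        )
        if result:
            print(result.ids[0], result.distances[0])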
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -1,25 +1,60 @@
 from typing import Optional
 import logging
 from urllib.parse import urlparse

 from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
-from open_webui.config import QDRANT_URI, QDRANT_API_KEY
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
+from open_webui.config import (
+    QDRANT_URI,
+    QDRANT_API_KEY,
+    QDRANT_ON_DISK,
+    QDRANT_GRPC_PORT,
+    QDRANT_PREFER_GRPC,
+)
 from open_webui.env import SRC_LOG_LEVELS

 NO_LIMIT = 999999999

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-class QdrantClient:
+class QdrantClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.QDRANT_URI = QDRANT_URI
         self.QDRANT_API_KEY = QDRANT_API_KEY
-        self.client = (
-            Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
-            if self.QDRANT_URI
-            else None
-        )
+        self.QDRANT_ON_DISK = QDRANT_ON_DISK
+        self.PREFER_GRPC = QDRANT_PREFER_GRPC
+        self.GRPC_PORT = QDRANT_GRPC_PORT
+
+        if not self.QDRANT_URI:
+            self.client = None
+            return
+
+        # Unified handling for either scheme
+        parsed = urlparse(self.QDRANT_URI)
+        host = parsed.hostname or self.QDRANT_URI
+        http_port = parsed.port or 6333  # default REST port
+
+        if self.PREFER_GRPC:
+            self.client = Qclient(
+                host=host,
+                port=http_port,
+                grpc_port=self.GRPC_PORT,
+                prefer_grpc=self.PREFER_GRPC,
+                api_key=self.QDRANT_API_KEY,
+            )
+        else:
+            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)

     def _result_to_get_result(self, points) -> GetResult:
         ids = []

@@ -45,11 +80,13 @@ class QdrantClient:
         self.client.create_collection(
             collection_name=collection_name_with_prefix,
             vectors_config=models.VectorParams(
-                size=dimension, distance=models.Distance.COSINE
+                size=dimension,
+                distance=models.Distance.COSINE,
+                on_disk=self.QDRANT_ON_DISK,
             ),
         )

-        print(f"collection {collection_name_with_prefix} successfully created!")
+        log.info(f"collection {collection_name_with_prefix} successfully created!")

     def _create_collection_if_not_exists(self, collection_name, dimension):
         if not self.has_collection(collection_name=collection_name):

@@ -94,7 +131,8 @@ class QdrantClient:
             ids=get_result.ids,
             documents=get_result.documents,
             metadatas=get_result.metadatas,
-            distances=[[point.score for point in query_response.points]],
+            # qdrant distance is [-1, 1], normalize to [0, 1]
+            distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
         )

     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):

@@ -120,7 +158,7 @@ class QdrantClient:
             )
             return self._result_to_get_result(points.points)
         except Exception as e:
-            print(e)
+            log.exception(f"Error querying a collection '{collection_name}': {e}")
             return None

     def get(self, collection_name: str) -> Optional[GetResult]:
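The new distance expression maps Qdrant's cosine score range [-1, 1] onto the [0, 1] scale the rest of the retrieval pipeline expects. A quick sanity check of the formula (illustrative, not part of the commit):

    def normalize(score: float) -> float:
        # (score + 1) / 2: -1.0 (opposite) -> 0.0,
        # 0.0 (orthogonal) -> 0.5, 1.0 (identical) -> 1.0
        return (score + 1.0) / 2.0

    assert normalize(-1.0) == 0.0
    assert normalize(0.0) == 0.5
    assert normalize(1.0) == 1.0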
712  backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py  Normal file

@@ -0,0 +1,712 @@
import logging
from typing import Optional, Tuple
from urllib.parse import urlparse

import grpc
from open_webui.config import (
    QDRANT_API_KEY,
    QDRANT_GRPC_PORT,
    QDRANT_ON_DISK,
    QDRANT_PREFER_GRPC,
    QDRANT_URI,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
    GetResult,
    SearchResult,
    VectorDBBase,
    VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models

NO_LIMIT = 999999999

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class QdrantClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        self.QDRANT_ON_DISK = QDRANT_ON_DISK
        self.PREFER_GRPC = QDRANT_PREFER_GRPC
        self.GRPC_PORT = QDRANT_GRPC_PORT

        if not self.QDRANT_URI:
            self.client = None
            return

        # Unified handling for either scheme
        parsed = urlparse(self.QDRANT_URI)
        host = parsed.hostname or self.QDRANT_URI
        http_port = parsed.port or 6333  # default REST port

        if self.PREFER_GRPC:
            self.client = Qclient(
                host=host,
                port=http_port,
                grpc_port=self.GRPC_PORT,
                prefer_grpc=self.PREFER_GRPC,
                api_key=self.QDRANT_API_KEY,
            )
        else:
            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"

    def _result_to_get_result(self, points) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps a traditional collection name to a multi-tenant collection and tenant ID.

        Returns:
            tuple: (collection_name, tenant_id)
        """
        tenant_id = collection_name

        # Check for user memory collections
        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, tenant_id

        # Check for file collections
        elif collection_name.startswith("file-"):
            return self.FILE_COLLECTION, tenant_id

        # Check for web search collections
        elif collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, tenant_id

        # Handle hash-based collections (YouTube and web URLs)
        elif len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, tenant_id

        else:
            return self.KNOWLEDGE_COLLECTION, tenant_id
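Under this mapping, every logical collection becomes a tenant inside one of five physical collections, so Qdrant only has to maintain a handful of HNSW indexes instead of one per knowledge base. Illustrative inputs and outputs (hypothetical names, assuming a configured client):

    # client._get_collection_and_tenant_id("user-memory-42")
    #   -> ("open-webui_memories", "user-memory-42")
    # client._get_collection_and_tenant_id("file-abc123")
    #   -> ("open-webui_files", "file-abc123")
    # client._get_collection_and_tenant_id("0" * 63)  # 63-char hex digest
    #   -> ("open-webui_hash-based", "000...0")
    # client._get_collection_and_tenant_id("my-knowledge-base")
    #   -> ("open-webui_knowledge", "my-knowledge-base")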
    def _extract_error_message(self, exception):
        """
        Extract the error message from either an HTTP or a gRPC exception.

        Returns:
            tuple: (status_code, error_message)
        """
        # Check if it's an HTTP exception
        if isinstance(exception, UnexpectedResponse):
            try:
                error_data = exception.structured()
                error_msg = error_data.get("status", {}).get("error", "")
                return exception.status_code, error_msg
            except Exception as inner_e:
                log.error(f"Failed to parse HTTP error: {inner_e}")
                return exception.status_code, str(exception)

        # Check if it's a gRPC exception
        elif isinstance(exception, grpc.RpcError):
            # Extract the status code from the gRPC error
            status_code = None
            if hasattr(exception, "code") and callable(exception.code):
                status_code = exception.code().value[0]

            # Extract the error message
            error_msg = str(exception)
            if "details =" in error_msg:
                # Parse the details line, which contains the actual error message
                try:
                    details_line = [
                        line.strip()
                        for line in error_msg.split("\n")
                        if "details =" in line
                    ][0]
                    error_msg = details_line.split("details =")[1].strip(' "')
                except (IndexError, AttributeError):
                    # Fall back to the full message if parsing fails
                    pass

            return status_code, error_msg

        # For any other type of exception
        return None, str(exception)

    def _is_collection_not_found_error(self, exception):
        """
        Check whether the exception is due to a missing collection, supporting both HTTP and gRPC.
        """
        status_code, error_msg = self._extract_error_message(exception)

        # HTTP error (404)
        if (
            status_code == 404
            and "Collection" in error_msg
            and "doesn't exist" in error_msg
        ):
            return True

        # gRPC error (NOT_FOUND status)
        if (
            isinstance(exception, grpc.RpcError)
            and exception.code() == grpc.StatusCode.NOT_FOUND
        ):
            return True

        return False

    def _is_dimension_mismatch_error(self, exception):
        """
        Check whether the exception is due to a dimension mismatch, supporting both HTTP and gRPC.
        """
        status_code, error_msg = self._extract_error_message(exception)

        # Common patterns in both HTTP and gRPC
        return (
            "Vector dimension error" in error_msg
            or "dimensions mismatch" in error_msg
            or "invalid vector size" in error_msg
        )

    def _create_multi_tenant_collection_if_not_exists(
        self, mt_collection_name: str, dimension: int = 384
    ):
        """
        Creates a collection with multi-tenancy configuration if it doesn't exist.
        The default dimension of 384 corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
        When collections are created dynamically (insert/upsert), the actual vector dimensions are used.
        """
        try:
            # Try to create the collection directly - this fails if it already exists
            self.client.create_collection(
                collection_name=mt_collection_name,
                vectors_config=models.VectorParams(
                    size=dimension,
                    distance=models.Distance.COSINE,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,  # Enable per-tenant indexing
                    m=0,
                    on_disk=self.QDRANT_ON_DISK,
                ),
            )

            # Create the tenant ID payload index
            self.client.create_payload_index(
                collection_name=mt_collection_name,
                field_name="tenant_id",
                field_schema=models.KeywordIndexParams(
                    type=models.KeywordIndexType.KEYWORD,
                    is_tenant=True,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                wait=True,
            )

            log.info(
                f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            # Check for the specific error indicating the collection already exists
            status_code, error_msg = self._extract_error_message(e)

            # HTTP status code 409 or gRPC ALREADY_EXISTS
            if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
                isinstance(e, grpc.RpcError)
                and e.code() == grpc.StatusCode.ALREADY_EXISTS
            ):
                if "already exists" in error_msg:
                    log.debug(f"Collection {mt_collection_name} already exists")
                    return
            # If it's not an already-exists error, re-raise
            raise e

    def _create_points(self, items: list[VectorItem], tenant_id: str):
        """
        Create point structs from vector items with a tenant ID.
        """
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={
                    "text": item["text"],
                    "metadata": item["metadata"],
                    "tenant_id": tenant_id,
                },
            )
            for item in items
        ]
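Each point therefore carries its tenant ID next to the document payload, which is what the tenant filters elsewhere in this class match against. For a hypothetical item:

    # item = {"id": "doc-1", "vector": [0.1, 0.2, 0.3],
    #         "text": "hello", "metadata": {"name": "a.txt"}}
    # _create_points([item], tenant_id="file-abc123") yields one PointStruct with
    # payload == {"text": "hello", "metadata": {"name": "a.txt"},
    #             "tenant_id": "file-abc123"}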
    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a logical collection exists by checking for any points with its tenant ID.
        """
        if not self.client:
            return False

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        try:
            # Query directly - most of the time the collection should exist
            response = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[tenant_filter]),
                limit=1,
            )

            # The collection exists for this tenant ID if there are points
            return len(response.points) > 0
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist")
                return False
            else:
                # For other API errors, log and return False
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error: {error_msg}")
                return False
        except Exception as e:
            # For any other errors, log and return False
            log.debug(f"Error checking collection {mt_collection}: {e}")
            return False

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """
        Delete vectors by ID or filter from a collection, with tenant isolation.
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        must_conditions = [tenant_filter]
        should_conditions = []

        if ids:
            for id_value in ids:
                should_conditions.append(
                    models.FieldCondition(
                        key="metadata.id",
                        match=models.MatchValue(value=id_value),
                    ),
                )
        elif filter:
            for key, value in filter.items():
                must_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    ),
                )

        try:
            # Try to delete directly - most of the time the collection should exist
            update_result = self.client.delete(
                collection_name=mt_collection,
                points_selector=models.FilterSelector(
                    filter=models.Filter(must=must_conditions, should=should_conditions)
                ),
            )

            return update_result
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, nothing to delete"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error: {error_msg}")
                raise

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for the nearest-neighbor items based on the vectors, with tenant isolation.
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Get the vector dimension from the query vector
        dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None

        try:
            # Try the search directly - most of the time the collection should exist

            # Create tenant filter
            tenant_filter = models.FieldCondition(
                key="tenant_id", match=models.MatchValue(value=tenant_id)
            )

            # Ensure the query vector dimensions match the collection
            collection_dim = self.client.get_collection(
                mt_collection
            ).config.params.vectors.size

            if collection_dim != dimension:
                if collection_dim < dimension:
                    vectors = [vector[:collection_dim] for vector in vectors]
                else:
                    vectors = [
                        vector + [0] * (collection_dim - dimension)
                        for vector in vectors
                    ]

            # Search with tenant filter
            prefetch_query = models.Prefetch(
                filter=models.Filter(must=[tenant_filter]),
                limit=NO_LIMIT,
            )
            query_response = self.client.query_points(
                collection_name=mt_collection,
                query=vectors[0],
                prefetch=prefetch_query,
                limit=limit,
            )

            get_result = self._result_to_get_result(query_response.points)
            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                # qdrant distance is [-1, 1], normalize to [0, 1]
                distances=[
                    [(point.score + 1.0) / 2.0 for point in query_response.points]
                ],
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, search returns None"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during search: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error searching collection '{collection_name}': {e}")
            return None

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """
        Query points with filters and tenant isolation.
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Set the default limit if not provided
        if limit is None:
            limit = NO_LIMIT

        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        # Create metadata filters
        field_conditions = []
        for key, value in filter.items():
            field_conditions.append(
                models.FieldCondition(
                    key=f"metadata.{key}", match=models.MatchValue(value=value)
                )
            )

        # Combine the tenant filter with the metadata filters
        combined_filter = models.Filter(must=[tenant_filter, *field_conditions])

        try:
            # Try the query directly - most of the time the collection should exist
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=combined_filter,
                limit=limit,
            )

            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, query returns None"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during query: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """
        Get all items in a collection, with tenant isolation.
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        try:
            # Try to get points directly - most of the time the collection should exist
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[tenant_filter]),
                limit=NO_LIMIT,
            )

            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during get: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error getting collection '{collection_name}': {e}")
            return None

    def _handle_operation_with_error_retry(
        self, operation_name, mt_collection, points, dimension
    ):
        """
        Private helper handling the common error cases for insert and upsert operations.

        Args:
            operation_name: 'insert' or 'upsert'
            mt_collection: The multi-tenant collection name
            points: The vector points to insert/upsert
            dimension: The dimension of the vectors

        Returns:
            The operation result (for upsert) or None (for insert)
        """
        try:
            if operation_name == "insert":
                self.client.upload_points(mt_collection, points)
                return None
            else:  # upsert
                return self.client.upsert(mt_collection, points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            # Handle collection not found
            if self._is_collection_not_found_error(e):
                log.info(
                    f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
                )
                # Create the collection with the correct dimensions from our vectors
                self._create_multi_tenant_collection_if_not_exists(
                    mt_collection_name=mt_collection, dimension=dimension
                )
                # Try the operation again - no dimension adjustment is needed, since
                # the collection was just created with the correct dimensions
                if operation_name == "insert":
                    self.client.upload_points(mt_collection, points)
                    return None
                else:  # upsert
                    return self.client.upsert(mt_collection, points)

            # Handle dimension mismatch
            elif self._is_dimension_mismatch_error(e):
                # For dimension errors the collection must exist, so get its configuration
                mt_collection_info = self.client.get_collection(mt_collection)
                existing_size = mt_collection_info.config.params.vectors.size

                log.info(
                    f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
                )

                if existing_size < dimension:
                    # Truncate vectors to fit
                    log.info(
                        f"Truncating vectors from {dimension} to {existing_size} dimensions"
                    )
                    points = [
                        PointStruct(
                            id=point.id,
                            vector=point.vector[:existing_size],
                            payload=point.payload,
                        )
                        for point in points
                    ]
                elif existing_size > dimension:
                    # Pad vectors with zeros
                    log.info(
                        f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
                    )
                    points = [
                        PointStruct(
                            id=point.id,
                            vector=point.vector
                            + [0] * (existing_size - len(point.vector)),
                            payload=point.payload,
                        )
                        for point in points
                    ]
                # Try the operation again with adjusted dimensions
                if operation_name == "insert":
                    self.client.upload_points(mt_collection, points)
                    return None
                else:  # upsert
                    return self.client.upsert(mt_collection, points)
            else:
                # Not a known error we can handle - log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unhandled Qdrant error: {error_msg}")
                raise

    def insert(self, collection_name: str, items: list[VectorItem]):
        """
        Insert items with a tenant ID.
        """
        if not self.client or not items:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Get dimensions from the actual vectors
        dimension = len(items[0]["vector"]) if items else None

        # Create points with tenant ID
        points = self._create_points(items, tenant_id)

        # Handle the operation with error retry
        return self._handle_operation_with_error_retry(
            "insert", mt_collection, points, dimension
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """
        Upsert items with a tenant ID.
        """
        if not self.client or not items:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Get dimensions from the actual vectors
        dimension = len(items[0]["vector"]) if items else None

        # Create points with tenant ID
        points = self._create_points(items, tenant_id)

        # Handle the operation with error retry
        return self._handle_operation_with_error_retry(
            "upsert", mt_collection, points, dimension
        )
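The retry helper above degrades gracefully on a dimension mismatch by truncating or zero-padding incoming vectors; the adjustment amounts to this sketch (illustrative, not part of the commit):

    def adjust(vector: list[float], existing_size: int) -> list[float]:
        if len(vector) > existing_size:
            return vector[:existing_size]  # truncate to fit
        return vector + [0] * (existing_size - len(vector))  # pad with zeros

    assert adjust([1.0, 2.0, 3.0], 2) == [1.0, 2.0]
    assert adjust([1.0], 3) == [1.0, 0.0, 0.0]

Note that both adjustments are lossy: similarity scores computed against truncated or padded vectors are not comparable to scores at the embedding model's native dimension.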
    def reset(self):
        """
        Reset the database by deleting all collections with this prefix.
        """
        if not self.client:
            return None

        collection_names = self.client.get_collections().collections
        for collection_name in collection_names:
            if collection_name.name.startswith(self.collection_prefix):
                self.client.delete_collection(collection_name=collection_name.name)

    def delete_collection(self, collection_name: str):
        """
        Delete a logical collection (all points for its tenant).
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        field_conditions = [tenant_filter]

        update_result = self.client.delete(
            collection_name=mt_collection,
            points_selector=models.FilterSelector(
                filter=models.Filter(must=field_conditions)
            ),
        )

        # Drop the physical collection once its last tenant is gone
        if self.client.get_collection(mt_collection).points_count == 0:
            self.client.delete_collection(mt_collection)

        return update_result
55  backend/open_webui/retrieval/vector/factory.py  Normal file

@@ -0,0 +1,55 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE


class Vector:

    @staticmethod
    def get_vector(vector_type: str) -> VectorDBBase:
        """
        Get a vector DB instance by vector type.
        """
        match vector_type:
            case VectorType.MILVUS:
                from open_webui.retrieval.vector.dbs.milvus import MilvusClient

                return MilvusClient()
            case VectorType.QDRANT:
                if ENABLE_QDRANT_MULTITENANCY_MODE:
                    from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
                        QdrantClient,
                    )

                    return QdrantClient()
                else:
                    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient

                    return QdrantClient()
            case VectorType.PINECONE:
                from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

                return PineconeClient()
            case VectorType.OPENSEARCH:
                from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

                return OpenSearchClient()
            case VectorType.PGVECTOR:
                from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

                return PgvectorClient()
            case VectorType.ELASTICSEARCH:
                from open_webui.retrieval.vector.dbs.elasticsearch import (
                    ElasticsearchClient,
                )

                return ElasticsearchClient()
            case VectorType.CHROMA:
                from open_webui.retrieval.vector.dbs.chroma import ChromaClient

                return ChromaClient()
            case _:
                raise ValueError(f"Unsupported vector type: {vector_type}")


VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)
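The factory resolves the backend once at import time, so callers work against the shared singleton rather than instantiating drivers themselves (illustrative sketch; the collection name and query vector are placeholders):

    # from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
    #
    # if VECTOR_DB_CLIENT.has_collection("file-abc123"):
    #     hits = VECTOR_DB_CLIENT.search("file-abc123", vectors=[query_vec], limit=3)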
backend/open_webui/retrieval/vector/main.py

@@ -1,5 +1,6 @@
 from pydantic import BaseModel
-from typing import Optional, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union


 class VectorItem(BaseModel):

@@ -17,3 +18,69 @@ class GetResult(BaseModel):

 class SearchResult(GetResult):
     distances: Optional[List[List[float | int]]]
+
+
+class VectorDBBase(ABC):
+    """
+    Abstract base class for all vector database backends.
+
+    Implementations of this class provide methods for collection management,
+    vector insertion, deletion, similarity search, and metadata filtering.
+
+    Any custom vector database integration must inherit from this class and
+    implement all abstract methods.
+    """
+
+    @abstractmethod
+    def has_collection(self, collection_name: str) -> bool:
+        """Check if the collection exists in the vector DB."""
+        pass
+
+    @abstractmethod
+    def delete_collection(self, collection_name: str) -> None:
+        """Delete a collection from the vector DB."""
+        pass
+
+    @abstractmethod
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert a list of vector items into a collection."""
+        pass
+
+    @abstractmethod
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert or update vector items in a collection."""
+        pass
+
+    @abstractmethod
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """Search for similar vectors in a collection."""
+        pass
+
+    @abstractmethod
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """Query vectors from a collection using a metadata filter."""
+        pass
+
+    @abstractmethod
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """Retrieve all vectors from a collection."""
+        pass
+
+    @abstractmethod
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        """Delete vectors by ID or filter from a collection."""
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Reset the vector database by removing all collections or those matching a condition."""
+        pass
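A new backend only has to subclass VectorDBBase and implement the nine abstract methods; a skeleton with a hypothetical backend name (illustrative and deliberately incomplete, not part of the commit):

    from typing import Dict, List

    from open_webui.retrieval.vector.main import VectorDBBase, VectorItem


    class InMemoryClient(VectorDBBase):
        """Toy illustration only: stores items in a dict, no real ANN search."""

        def __init__(self):
            self.collections: Dict[str, List[VectorItem]] = {}

        def has_collection(self, collection_name: str) -> bool:
            return collection_name in self.collections

        def insert(self, collection_name: str, items: List[VectorItem]) -> None:
            self.collections.setdefault(collection_name, []).extend(items)

        # upsert, search, query, get, delete, delete_collection, and reset are
        # omitted here; all nine @abstractmethod members must be implemented
        # before the class can be instantiated.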
11  backend/open_webui/retrieval/vector/type.py  Normal file

@@ -0,0 +1,11 @@
from enum import StrEnum


class VectorType(StrEnum):
    MILVUS = "milvus"
    QDRANT = "qdrant"
    CHROMA = "chroma"
    PINECONE = "pinecone"
    ELASTICSEARCH = "elasticsearch"
    OPENSEARCH = "opensearch"
    PGVECTOR = "pgvector"
backend/open_webui/retrieval/web/duckduckgo.py

@@ -3,6 +3,7 @@ from typing import Optional

 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
+from duckduckgo_search.exceptions import RatelimitException
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)

@@ -22,29 +23,24 @@ def search_duckduckgo(
         list[SearchResult]: A list of search results
     """
     # Use the DDGS context manager to create a DDGS object
+    search_results = []
     with DDGS() as ddgs:
         # Use the ddgs.text() method to perform the search
-        ddgs_gen = ddgs.text(
-            query, safesearch="moderate", max_results=count, backend="api"
-        )
-        # Check if there are search results
-        if ddgs_gen:
-            # Convert the search results into a list
-            search_results = [r for r in ddgs_gen]
-
-    # Create an empty list to store the SearchResult objects
-    results = []
-    # Iterate over each search result
-    for result in search_results:
-        # Create a SearchResult object and append it to the results list
-        results.append(
-            SearchResult(
-                link=result["href"],
-                title=result.get("title"),
-                snippet=result.get("body"),
-            )
-        )
+        try:
+            search_results = ddgs.text(
+                query, safesearch="moderate", max_results=count, backend="lite"
+            )
+        except RatelimitException as e:
+            log.error(f"RatelimitException: {e}")
     if filter_list:
-        results = get_filtered_results(results, filter_list)
+        search_results = get_filtered_results(search_results, filter_list)

-    # Return the list of search results
-    return results
+    return [
+        SearchResult(
+            link=result["href"],
+            title=result.get("title"),
+            snippet=result.get("body"),
+        )
+        for result in search_results
+    ]
47  backend/open_webui/retrieval/web/external.py  Normal file

@@ -0,0 +1,47 @@
import logging
from typing import Optional, List

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_external(
    external_url: str,
    external_api_key: str,
    query: str,
    count: int,
    filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
    try:
        response = requests.post(
            external_url,
            headers={
                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
                "Authorization": f"Bearer {external_api_key}",
            },
            json={
                "query": query,
                "count": count,
            },
        )
        response.raise_for_status()
        results = response.json()
        if filter_list:
            results = get_filtered_results(results, filter_list)
        results = [
            SearchResult(
                link=result.get("link"),
                title=result.get("title"),
                snippet=result.get("snippet"),
            )
            for result in results[:count]
        ]
        log.info(f"External search results: {results}")
        return results
    except Exception as e:
        log.error(f"Error in External search: {e}")
        return []
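The endpoint contract this helper assumes (inferred from the request/response handling above, not from separate documentation): the external service accepts a JSON body with query and count, authenticated via a Bearer token, and returns a JSON array of objects with link, title, and snippet keys:

    # POST <external_url>
    #   {"query": "open webui", "count": 3}
    # -> 200 OK
    #   [{"link": "https://example.com", "title": "Example", "snippet": "..."},
    #    ...]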
49  backend/open_webui/retrieval/web/firecrawl.py  Normal file

@@ -0,0 +1,49 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_firecrawl(
    firecrawl_url: str,
    firecrawl_api_key: str,
    query: str,
    count: int,
    filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
    try:
        firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
        response = requests.post(
            firecrawl_search_url,
            headers={
                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
                "Authorization": f"Bearer {firecrawl_api_key}",
            },
            json={
                "query": query,
                "limit": count,
            },
        )
        response.raise_for_status()
        results = response.json().get("data", [])
        if filter_list:
            results = get_filtered_results(results, filter_list)
        results = [
            SearchResult(
                link=result.get("url"),
                title=result.get("title"),
                snippet=result.get("description"),
            )
            for result in results[:count]
        ]
        log.info(f"Firecrawl search results: {results}")
        return results
    except Exception as e:
        log.error(f"Error in Firecrawl search: {e}")
        return []
87  backend/open_webui/retrieval/web/perplexity.py  Normal file

@@ -0,0 +1,87 @@
import logging
from typing import Optional, List
import requests

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_perplexity(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using the Perplexity API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Perplexity API key
        query (str): The query to search for
        count (int): Maximum number of results to return
    """

    # Handle PersistentConfig objects by coercing to str
    if hasattr(api_key, "__str__"):
        api_key = str(api_key)

    try:
        url = "https://api.perplexity.ai/chat/completions"

        # Create the payload for the API call
        payload = {
            "model": "sonar",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a search assistant. Provide factual information with citations.",
                },
                {"role": "user", "content": query},
            ],
            "temperature": 0.2,  # Lower temperature for more factual responses
            "stream": False,
        }

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Make the API request
        response = requests.request("POST", url, json=payload, headers=headers)

        # Parse the JSON response
        json_response = response.json()

        # Extract citations from the response
        citations = json_response.get("citations", [])

        # Create search results from the citations
        results = []
        for i, citation in enumerate(citations[:count]):
            # Use the completion content as the snippet for the first citation
            content = ""
            if "choices" in json_response and json_response["choices"]:
                if i == 0:
                    content = json_response["choices"][0]["message"]["content"]

            result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
            results.append(result)

        if filter_list:
            results = get_filtered_results(results, filter_list)

        return [
            SearchResult(
                link=result["link"], title=result["title"], snippet=result["snippet"]
            )
            for result in results[:count]
        ]

    except Exception as e:
        log.error(f"Error searching with Perplexity API: {e}")
        return []
backend/open_webui/retrieval/web/searchapi.py

@@ -42,7 +42,9 @@ def search_searchapi(
         results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
-            link=result["link"], title=result["title"], snippet=result["snippet"]
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
         )
         for result in results[:count]
     ]
backend/open_webui/retrieval/web/serpapi.py

@@ -42,7 +42,9 @@ def search_serpapi(
         results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
-            link=result["link"], title=result["title"], snippet=result["snippet"]
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
        )
         for result in results[:count]
     ]
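Both the SearchAPI and SerpAPI changes above swap direct indexing for .get() on the optional fields, so a hit missing title or snippet maps to None instead of raising; link is still indexed directly and remains required. For example:

    hit = {"link": "https://example.com"}  # a result with no title/snippet
    hit.get("title")   # -> None (tolerated by the new code)
    hit["snippet"]     # -> KeyError (what the old code would raise)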
|||
60
backend/open_webui/retrieval/web/sougou.py
Normal file
60
backend/open_webui/retrieval/web/sougou.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import logging
|
||||
import json
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_sougou(
|
||||
sougou_api_sid: str,
|
||||
sougou_api_sk: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[List[str]] = None,
|
||||
) -> List[SearchResult]:
|
||||
from tencentcloud.common.common_client import CommonClient
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
|
||||
try:
|
||||
cred = credential.Credential(sougou_api_sid, sougou_api_sk)
|
||||
http_profile = HttpProfile()
|
||||
http_profile.endpoint = "tms.tencentcloudapi.com"
|
||||
client_profile = ClientProfile()
|
||||
client_profile.http_profile = http_profile
|
||||
params = json.dumps({"Query": query, "Cnt": 20})
|
||||
common_client = CommonClient(
|
||||
"tms", "2020-12-29", cred, "", profile=client_profile
|
||||
)
|
||||
results = [
|
||||
json.loads(page)
|
||||
for page in common_client.call_json("SearchPro", json.loads(params))[
|
||||
"Response"
|
||||
]["Pages"]
|
||||
]
|
||||
sorted_results = sorted(
|
||||
results, key=lambda x: x.get("scour", 0.0), reverse=True
|
||||
)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.get("url"),
|
||||
title=result.get("title"),
|
||||
snippet=result.get("passage"),
|
||||
)
|
||||
for result in sorted_results[:count]
|
||||
]
|
||||
except TencentCloudSDKException as err:
|
||||
log.error(f"Error in Sougou search: {err}")
|
||||
return []
|
||||
|
|
backend/open_webui/retrieval/web/tavily.py

@@ -1,32 +1,45 @@
 import logging
+from typing import Optional

 import requests
-from open_webui.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_tavily(
+    api_key: str,
+    query: str,
+    count: int,
+    filter_list: Optional[list[str]] = None,
+    # **kwargs,
+) -> list[SearchResult]:
     """Search using Tavily's Search API and return the results as a list of SearchResult objects.

     Args:
         api_key (str): A Tavily Search API key
         query (str): The query to search for
+        count (int): The maximum number of results to return

     Returns:
         list[SearchResult]: A list of search results
     """
     url = "https://api.tavily.com/search"
-    data = {"query": query, "api_key": api_key}
-
-    response = requests.post(url, json=data)
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+    }
+    data = {"query": query, "max_results": count}
+    response = requests.post(url, headers=headers, json=data)
     response.raise_for_status()

     json_response = response.json()

-    raw_search_results = json_response.get("results", [])
+    results = json_response.get("results", [])
+    if filter_list:
+        results = get_filtered_results(results, filter_list)

     return [
         SearchResult(

@@ -34,5 +47,5 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
             title=result.get("title", ""),
             snippet=result.get("content"),
         )
-        for result in raw_search_results[:count]
+        for result in results
     ]
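After this change the Tavily key travels in an Authorization header instead of the JSON body, and the result count is capped server-side via max_results; the equivalent raw request looks roughly like this (illustrative):

    # POST https://api.tavily.com/search
    # Content-Type: application/json
    # Authorization: Bearer <TAVILY_API_KEY>
    #
    # {"query": "open webui", "max_results": 5}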
@ -1,19 +1,45 @@
|
|||
import socket
|
||||
import urllib.parse
|
||||
import validators
|
||||
from typing import Union, Sequence, Iterator
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
Literal,
|
||||
)
|
||||
import aiohttp
|
||||
import certifi
|
||||
import validators
|
||||
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
||||
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
PLAYWRIGHT_WS_URL,
|
||||
PLAYWRIGHT_TIMEOUT,
|
||||
WEB_LOADER_ENGINE,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
TAVILY_API_KEY,
|
||||
TAVILY_EXTRACT_DEPTH,
|
||||
EXTERNAL_WEB_LOADER_URL,
|
||||
EXTERNAL_WEB_LOADER_API_KEY,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
|
@ -65,9 +91,472 @@ def resolve_hostname(hostname):
|
|||
return ipv4_addresses, ipv6_addresses
|
||||
|
||||
|
||||
def extract_metadata(soup, url):
|
||||
metadata = {"source": url}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get("content", "No description found.")
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
return metadata
|
||||
|
||||
|
||||
def verify_ssl_cert(url: str) -> bool:
|
||||
"""Verify SSL certificate for the given URL."""
|
||||
if not url.startswith("https://"):
|
||||
return True
|
||||
|
||||
try:
|
||||
hostname = url.split("://")[-1].split("/")[0]
|
||||
context = ssl.create_default_context(cafile=certifi.where())
|
||||
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
|
||||
s.connect((hostname, 443))
|
||||
return True
|
||||
except ssl.SSLError:
|
||||
return False
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RateLimitMixin:
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
|
||||
class URLProcessingMixin:
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
"""Verify SSL certificate for a URL."""
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True


class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
    def __init__(
        self,
        web_paths,
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Literal["crawl", "scrape", "map"] = "scrape",
        proxy: Optional[Dict[str, str]] = None,
        params: Optional[Dict] = None,
    ):
        """Document loader for FireCrawl operations.

        Wraps FireCrawlLoader with SSL verification and rate limiting,
        processing each URL in turn before delegating to FireCrawl.

        Args:
            web_paths: List of URLs/paths to process.
            verify_ssl: If True, verify SSL certificates.
            trust_env: If True, use proxy settings from environment variables.
            requests_per_second: Number of requests per second to limit to.
            continue_on_failure (bool): If True, continue loading other URLs on failure.
            api_key: API key for the FireCrawl service. Defaults to None
                (uses the FIRE_CRAWL_API_KEY environment variable if not provided).
            api_url: Base URL for the FireCrawl API. Defaults to the official API endpoint.
            mode: Operation mode selection:
                - 'scrape': Direct page scraping (default)
                - 'crawl': Website crawling
                - 'map': Site map generation
            proxy: Proxy override settings for the FireCrawl API.
            params: The parameters to pass to the FireCrawl API,
                e.g. crawlerOptions.
                For more details, visit: https://github.com/mendableai/firecrawl-py
        """
        proxy_server = proxy.get("server") if proxy else None
        if trust_env and not proxy_server:
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                if proxy:
                    proxy["server"] = env_proxy_server
                else:
                    proxy = {"server": env_proxy_server}
        self.web_paths = web_paths
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.trust_env = trust_env
        self.continue_on_failure = continue_on_failure
        self.api_key = api_key
        self.api_url = api_url
        self.mode = mode
        self.params = params

    def lazy_load(self) -> Iterator[Document]:
        """Load documents one URL at a time using FireCrawl."""
        for url in self.web_paths:
            try:
                self._safe_process_url_sync(url)
                loader = FireCrawlLoader(
                    url=url,
                    api_key=self.api_key,
                    api_url=self.api_url,
                    mode=self.mode,
                    params=self.params,
                )
                for document in loader.lazy_load():
                    if not document.metadata.get("source"):
                        document.metadata["source"] = document.metadata.get("sourceURL")
                    yield document
            except Exception as e:
                if self.continue_on_failure:
                    log.exception(f"Error loading {url}: {e}")
                    continue
                raise e

    async def alazy_load(self):
        """Async version of lazy_load."""
        for url in self.web_paths:
            try:
                await self._safe_process_url(url)
                loader = FireCrawlLoader(
                    url=url,
                    api_key=self.api_key,
                    api_url=self.api_url,
                    mode=self.mode,
                    params=self.params,
                )
                async for document in loader.alazy_load():
                    if not document.metadata.get("source"):
                        document.metadata["source"] = document.metadata.get("sourceURL")
                    yield document
            except Exception as e:
                if self.continue_on_failure:
                    log.exception(f"Error loading {url}: {e}")
                    continue
                raise e
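

def _demo_firecrawl_loader():  # illustrative sketch, not part of the upstream module
    # The API key below is a placeholder; FireCrawl falls back to the
    # FIRE_CRAWL_API_KEY environment variable when api_key is None.
    loader = SafeFireCrawlLoader(
        web_paths=["https://example.com"],
        api_key="fc-...",  # hypothetical key
        mode="scrape",
        requests_per_second=1.0,
    )
    for doc in loader.lazy_load():
        print(doc.metadata.get("source"))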


class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
    def __init__(
        self,
        web_paths: Union[str, List[str]],
        api_key: str,
        extract_depth: Literal["basic", "advanced"] = "basic",
        continue_on_failure: bool = True,
        requests_per_second: Optional[float] = None,
        verify_ssl: bool = True,
        trust_env: bool = False,
        proxy: Optional[Dict[str, str]] = None,
    ):
        """Initialize SafeTavilyLoader with rate limiting and SSL verification support.

        Args:
            web_paths: List of URLs/paths to process.
            api_key: The Tavily API key.
            extract_depth: Depth of extraction ("basic" or "advanced").
            continue_on_failure: Whether to continue if extraction of a URL fails.
            requests_per_second: Number of requests per second to limit to.
            verify_ssl: If True, verify SSL certificates.
            trust_env: If True, use proxy settings from environment variables.
            proxy: Optional proxy configuration.
        """
        # Initialize proxy configuration if using environment variables
        proxy_server = proxy.get("server") if proxy else None
        if trust_env and not proxy_server:
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                if proxy:
                    proxy["server"] = env_proxy_server
                else:
                    proxy = {"server": env_proxy_server}

        # Store parameters for creating TavilyLoader instances
        self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
        self.api_key = api_key
        self.extract_depth = extract_depth
        self.continue_on_failure = continue_on_failure
        self.verify_ssl = verify_ssl
        self.trust_env = trust_env
        self.proxy = proxy

        # Add rate limiting
        self.requests_per_second = requests_per_second
        self.last_request_time = None

    def lazy_load(self) -> Iterator[Document]:
        """Load documents with rate limiting support, delegating to TavilyLoader."""
        valid_urls = []
        for url in self.web_paths:
            try:
                self._safe_process_url_sync(url)
                valid_urls.append(url)
            except Exception as e:
                log.warning(f"SSL verification failed for {url}: {str(e)}")
                if not self.continue_on_failure:
                    raise e
        if not valid_urls:
            if self.continue_on_failure:
                log.warning("No valid URLs to process after SSL verification")
                return
            raise ValueError("No valid URLs to process after SSL verification")
        try:
            loader = TavilyLoader(
                urls=valid_urls,
                api_key=self.api_key,
                extract_depth=self.extract_depth,
                continue_on_failure=self.continue_on_failure,
            )
            yield from loader.lazy_load()
        except Exception as e:
            if self.continue_on_failure:
                log.exception(f"Error extracting content from URLs: {e}")
            else:
                raise e

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async version with rate limiting and SSL verification."""
        valid_urls = []
        for url in self.web_paths:
            try:
                await self._safe_process_url(url)
                valid_urls.append(url)
            except Exception as e:
                log.warning(f"SSL verification failed for {url}: {str(e)}")
                if not self.continue_on_failure:
                    raise e

        if not valid_urls:
            if self.continue_on_failure:
                log.warning("No valid URLs to process after SSL verification")
                return
            raise ValueError("No valid URLs to process after SSL verification")

        try:
            loader = TavilyLoader(
                urls=valid_urls,
                api_key=self.api_key,
                extract_depth=self.extract_depth,
                continue_on_failure=self.continue_on_failure,
            )
            async for document in loader.alazy_load():
                yield document
        except Exception as e:
            if self.continue_on_failure:
                log.exception(f"Error loading URLs: {e}")
            else:
                raise e
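

def _demo_tavily_loader():  # illustrative sketch, not part of the upstream module
    # A single string is accepted for web_paths and normalized to a list; the
    # API key below is a placeholder.
    loader = SafeTavilyLoader(
        web_paths="https://example.com",
        api_key="tvly-...",  # hypothetical key
        extract_depth="basic",
        requests_per_second=1.0,
    )
    for doc in loader.lazy_load():
        print(doc.metadata.get("source"))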


class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connections.

    Attributes:
        web_paths (List[str]): List of URLs to load.
        verify_ssl (bool): If True, verify SSL certificates.
        trust_env (bool): If True, use proxy settings from environment variables.
        requests_per_second (Optional[float]): Number of requests per second to limit to.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        headless (bool): If True, the browser will run in headless mode.
        proxy (dict): Proxy override settings for the Playwright session.
        playwright_ws_url (Optional[str]): WebSocket endpoint URI for a remote browser connection.
        playwright_timeout (Optional[int]): Maximum operation time in milliseconds.
    """

    def __init__(
        self,
        web_paths: List[str],
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        headless: bool = True,
        remove_selectors: Optional[List[str]] = None,
        proxy: Optional[Dict[str, str]] = None,
        playwright_ws_url: Optional[str] = None,
        playwright_timeout: Optional[int] = 10000,
    ):
        """Initialize with additional safety parameters and remote browser support."""

        proxy_server = proxy.get("server") if proxy else None
        if trust_env and not proxy_server:
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                if proxy:
                    proxy["server"] = env_proxy_server
                else:
                    proxy = {"server": env_proxy_server}

        # Set headless to False when using playwright_ws_url, since headless mode
        # is handled by the remote browser
        super().__init__(
            urls=web_paths,
            continue_on_failure=continue_on_failure,
            headless=headless if playwright_ws_url is None else False,
            remove_selectors=remove_selectors,
            proxy=proxy,
        )
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.playwright_ws_url = playwright_ws_url
        self.trust_env = trust_env
        self.playwright_timeout = playwright_timeout

    def lazy_load(self) -> Iterator[Document]:
        """Safely load URLs synchronously with support for a remote browser."""
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            # Use the remote browser if a ws endpoint is provided, otherwise launch locally
            if self.playwright_ws_url:
                browser = p.chromium.connect(self.playwright_ws_url)
            else:
                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)

            for url in self.urls:
                try:
                    self._safe_process_url_sync(url)
                    page = browser.new_page()
                    response = page.goto(url, timeout=self.playwright_timeout)
                    if response is None:
                        raise ValueError(f"page.goto() returned None for url {url}")

                    text = self.evaluator.evaluate(page, browser, response)
                    metadata = {"source": url}
                    yield Document(page_content=text, metadata=metadata)
                except Exception as e:
                    if self.continue_on_failure:
                        log.exception(f"Error loading {url}: {e}")
                        continue
                    raise e
            browser.close()

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Safely load URLs asynchronously with support for a remote browser."""
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            # Use the remote browser if a ws endpoint is provided, otherwise launch locally
            if self.playwright_ws_url:
                browser = await p.chromium.connect(self.playwright_ws_url)
            else:
                browser = await p.chromium.launch(
                    headless=self.headless, proxy=self.proxy
                )

            for url in self.urls:
                try:
                    await self._safe_process_url(url)
                    page = await browser.new_page()
                    response = await page.goto(url, timeout=self.playwright_timeout)
                    if response is None:
                        raise ValueError(f"page.goto() returned None for url {url}")

                    text = await self.evaluator.evaluate_async(page, browser, response)
                    metadata = {"source": url}
                    yield Document(page_content=text, metadata=metadata)
                except Exception as e:
                    if self.continue_on_failure:
                        log.exception(f"Error loading {url}: {e}")
                        continue
                    raise e
            await browser.close()
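

def _demo_playwright_loader():  # illustrative sketch, not part of the upstream module
    # When playwright_ws_url is set the loader connects to a remote browser and
    # skips the local Chromium launch; the ws:// endpoint below is a placeholder.
    loader = SafePlaywrightURLLoader(
        web_paths=["https://example.com"],
        playwright_ws_url="ws://playwright:3000",  # hypothetical endpoint
        playwright_timeout=10000,  # milliseconds
        requests_per_second=2.0,
    )
    for doc in loader.lazy_load():
        print(len(doc.page_content))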


class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def __init__(self, trust_env: bool = False, *args, **kwargs):
        """Initialize SafeWebBaseLoader.

        Args:
            trust_env (bool, optional): Set to True when using a proxy to make web
                requests, for example via the http(s)_proxy environment variables.
                Defaults to False.
        """
        super().__init__(*args, **kwargs)
        self.trust_env = trust_env

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    kwargs: Dict = dict(
                        headers=self.session.headers,
                        cookies=self.session.cookies.get_dict(),
                    )
                    if not self.session.verify:
                        kwargs["ssl"] = False

                    async with session.get(
                        url,
                        **(self.requests_kwargs | kwargs),
                    ) as response:
                        if self.raise_for_status:
                            response.raise_for_status()
                        return await response.text()
                except aiohttp.ClientConnectionError as e:
                    if i == retries - 1:
                        raise
                    else:
                        log.warning(
                            f"Error fetching {url} with attempt "
                            f"{i + 1}/{retries}: {e}. Retrying..."
                        )
                        await asyncio.sleep(cooldown * backoff**i)
        raise ValueError("retry count exceeded")

    def _unpack_fetch_results(
        self, results: Any, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Unpack fetch results into BeautifulSoup objects."""
        from bs4 import BeautifulSoup

        final_results = []
        for i, result in enumerate(results):
            url = urls[i]
            if parser is None:
                if url.endswith(".xml"):
                    parser = "xml"
                else:
                    parser = self.default_parser
                self._check_parser(parser)
            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
        return final_results

    async def ascrape_all(
        self, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Async fetch all URLs, then return soups for all results."""
        results = await self.fetch_all(urls)
        return self._unpack_fetch_results(results, urls, parser=parser)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the URL(s) in web_path with error handling."""
        for path in self.web_paths:

@ -76,33 +565,86 @@ class SafeWebBaseLoader(WebBaseLoader):
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = extract_metadata(soup, path)

                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.exception(f"Error loading {path}: {e}")

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the URL(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
            if description := soup.find("meta", attrs={"name": "description"}):
                metadata["description"] = description.get(
                    "content", "No description found."
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]
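

def _demo_fetch_backoff():  # illustrative sketch, not part of the upstream module
    # _fetch() sleeps cooldown * backoff**i after failed attempt i, so with the
    # defaults (retries=3, cooldown=2, backoff=1.5) the waits before the second
    # and third attempts are 2.0s and 3.0s.
    cooldown, backoff, retries = 2, 1.5, 3
    print([cooldown * backoff**i for i in range(retries - 1)])  # -> [2.0, 3.0]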


def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
    trust_env: bool = False,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    web_loader_args = {
        "web_paths": safe_urls,
        "verify_ssl": verify_ssl,
        "requests_per_second": requests_per_second,
        "continue_on_failure": True,
        "trust_env": trust_env,
    }

    WebLoaderClass = None  # falls through to the ValueError below for unknown engines
    if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
        WebLoaderClass = SafeWebBaseLoader
    if WEB_LOADER_ENGINE.value == "playwright":
        WebLoaderClass = SafePlaywrightURLLoader
        web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
        if PLAYWRIGHT_WS_URL.value:
            web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value

    if WEB_LOADER_ENGINE.value == "firecrawl":
        WebLoaderClass = SafeFireCrawlLoader
        web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
        web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value

    if WEB_LOADER_ENGINE.value == "tavily":
        WebLoaderClass = SafeTavilyLoader
        web_loader_args["api_key"] = TAVILY_API_KEY.value
        web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value

    if WEB_LOADER_ENGINE.value == "external":
        WebLoaderClass = ExternalWebLoader
        web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
        web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value

    if WebLoaderClass:
        web_loader = WebLoaderClass(**web_loader_args)

        log.debug(
            "Using WEB_LOADER_ENGINE %s for %s URLs",
            web_loader.__class__.__name__,
            len(safe_urls),
        )

        return web_loader
    else:
        raise ValueError(
            f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
            "Please set it to 'safe_web', 'playwright', 'firecrawl', 'tavily', or 'external'."
        )
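

def _demo_get_web_loader():  # illustrative sketch, not part of the upstream module
    # Which Safe*Loader class get_web_loader() returns is driven entirely by the
    # WEB_LOADER_ENGINE config value; the URL below is a placeholder.
    loader = get_web_loader(
        "https://example.com",
        verify_ssl=True,
        requests_per_second=2,
    )
    print(type(loader).__name__)  # e.g. SafeWebBaseLoader for the default engine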

87
backend/open_webui/retrieval/web/yacy.py
Normal file

@ -0,0 +1,87 @@
import logging
from typing import Optional

import requests
from requests.auth import HTTPDigestAuth
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_yacy(
    query_url: str,
    username: Optional[str],
    password: Optional[str],
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """
    Search a YaCy instance for a given query and return the results as a list of SearchResult objects.

    The function accepts an optional username and password for authenticating to YaCy.

    Args:
        query_url (str): The base URL of the YaCy server.
        username (str): Optional YaCy username.
        password (str): Optional YaCy password.
        query (str): The search term or question to find in the YaCy database.
        count (int): The maximum number of results to retrieve from the search.

    Returns:
        list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.

    Raises:
        requests.exceptions.RequestException: If a request error occurs during the search process.
    """

    # Use authentication if either username or password is set
    yacy_auth = None
    if username or password:
        yacy_auth = HTTPDigestAuth(username, password)

    params = {
        "query": query,
        "contentdom": "text",
        "resource": "global",
        "maximumRecords": count,
        "nav": "none",
    }

    # Make sure we are calling the JSON search API
    if not query_url.endswith("yacysearch.json"):
        # Append the JSON endpoint to the base URL
        query_url = query_url.rstrip("/") + "/yacysearch.json"

    log.debug(f"searching {query_url}")

    response = requests.get(
        query_url,
        auth=yacy_auth,
        headers={
            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
            "Accept": "text/html",
            "Accept-Encoding": "gzip, deflate",
            "Accept-Language": "en-US,en;q=0.5",
            "Connection": "keep-alive",
        },
        params=params,
    )

    response.raise_for_status()  # Raise an exception for HTTP errors.

    json_response = response.json()
    results = json_response.get("channels", [{}])[0].get("items", [])
    sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
    if filter_list:
        sorted_results = get_filtered_results(sorted_results, filter_list)
    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("description"),
        )
        for result in sorted_results[:count]
    ]
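

def _demo_search_yacy():  # illustrative sketch, not part of the new module
    # The instance URL is a placeholder; digest auth is only attached when a
    # username or password is configured.
    results = search_yacy(
        query_url="http://localhost:8090",  # hypothetical YaCy instance
        username=None,
        password=None,
        query="open source search",
        count=5,
    )
    for result in results:
        print(result.link, result.title)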

@ -7,6 +7,9 @@ from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


import aiohttp
import aiofiles

@ -17,6 +20,7 @@ from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,

@ -33,10 +37,13 @@ from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,

@ -47,13 +54,15 @@ from open_webui.env import (
router = APIRouter()

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

@ -67,27 +76,47 @@ from pydub import AudioSegment
from pydub.utils import mediainfo


def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}

    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False

    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()

        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio; recommend mp3 conversion
            return True

        # If the codec name is already in the supported formats, no conversion is needed
        if codec_name in SUPPORTED_FORMATS:
            return False

        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False


def convert_audio_to_mp3(file_path):
    """Convert an audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
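

def _demo_normalize_audio(file_path):  # illustrative sketch, not part of the upstream module
    # Typical call pattern for the two helpers above: probe the codec first and
    # only transcode when the container is outside the supported set, falling
    # back to the original path if the conversion fails.
    if is_audio_conversion_required(file_path):
        converted = convert_audio_to_mp3(file_path)
        return converted or file_path
    return file_path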


def set_faster_whisper_model(model: str, auto_update: bool = False):

@ -130,6 +159,7 @@ class TTSConfigForm(BaseModel):
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


@ -140,6 +170,11 @@ class STTConfigForm(BaseModel):
    MODEL: str
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str
    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str


class AudioConfigUpdateForm(BaseModel):

@ -159,6 +194,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {

@ -168,6 +204,11 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
            "MODEL": request.app.state.config.STT_MODEL,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }

@ -184,6 +225,9 @@ async def update_audio_config(
    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
        form_data.tts.AZURE_SPEECH_BASE_URL
    )
    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

@ -194,6 +238,13 @@ async def update_audio_config(
    request.app.state.config.STT_MODEL = form_data.stt.MODEL
    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
        form_data.stt.AZURE_MAX_SPEAKERS
    )

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(

@ -210,6 +261,7 @@ async def update_audio_config(
            "VOICE": request.app.state.config.TTS_VOICE,
            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {

@ -219,6 +271,11 @@ async def update_audio_config(
            "MODEL": request.app.state.config.STT_MODEL,
            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }

@ -265,8 +322,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        payload["model"] = request.app.state.config.TTS_MODEL

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                    json=payload,

@ -284,6 +343,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                        else {}
                        ),
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

@ -309,7 +369,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
            detail = f"External: {e}"

        raise HTTPException(
            status_code=getattr(r, "status", 500) if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )

@ -323,7 +383,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        )

        try:
            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(
                timeout=timeout, trust_env=True
            ) as session:
                async with session.post(
                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                    json={

@ -336,6 +399,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                        "Content-Type": "application/json",
                        "xi-api-key": request.app.state.config.TTS_API_KEY,
                    },
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    r.raise_for_status()

@ -360,7 +424,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
            detail = f"External: {e}"

        raise HTTPException(
            status_code=getattr(r, "status", 500) if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )

@ -371,7 +435,8 @@ async def speech(request: Request, user=Depends(get_verified_user)):
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
        base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
        language = request.app.state.config.TTS_VOICE
        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT

@ -380,15 +445,20 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        async with aiohttp.ClientSession(
            timeout=timeout, trust_env=True
        ) as session:
            async with session.post(
                (base_url or f"https://{region}.tts.speech.microsoft.com")
                + "/cognitiveservices/v1",
                headers={
                    "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/ssml+xml",
                    "X-Microsoft-OutputFormat": output_format,
                },
                data=data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()

@ -413,7 +483,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
            detail = f"External: {e}"

        raise HTTPException(
            status_code=getattr(r, "status", 500) if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )

@ -457,12 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
    return FileResponse(file_path)


def transcription_handler(request, file_path, metadata):
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    metadata = metadata or {}

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(

@ -470,7 +541,12 @@ def transcribe(request: Request, file_path):
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
            language=metadata.get("language") or WHISPER_LANGUAGE,
        )
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)

@ -487,11 +563,6 @@ def transcribe(request: Request, file_path):
        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        r = None
        try:
            r = requests.post(

@ -500,7 +571,14 @@ def transcribe(request: Request, file_path):
                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                },
                files={"file": (filename, open(file_path, "rb"))},
                data={
                    "model": request.app.state.config.STT_MODEL,
                    **(
                        {"language": metadata.get("language")}
                        if metadata.get("language")
                        else {}
                    ),
                },
            )

            r.raise_for_status()

@ -589,34 +667,254 @@ def transcribe(request: Request, file_path):
            detail = f"External: {e}"
        raise Exception(detail if detail else "Open WebUI: Server Connection Error")

    elif request.app.state.config.STT_ENGINE == "azure":
        # Check that the file exists
        if not os.path.exists(file_path):
            raise HTTPException(status_code=400, detail="Audio file not found")

        # Check the file size (Azure has a larger limit of 200MB)
        file_size = os.path.getsize(file_path)
        if file_size > AZURE_MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
            )

        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3

        # If no locales are configured, fall back to a default set
        if len(locales) < 2:
            locales = [
                "en-US",
                "es-ES",
                "es-MX",
                "fr-FR",
                "hi-IN",
                "it-IT",
                "de-DE",
                "en-GB",
                "en-IN",
                "ja-JP",
                "ko-KR",
                "pt-BR",
                "zh-CN",
            ]
            locales = ",".join(locales)

        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key is required for Azure STT",
            )

        r = None
        try:
            # Prepare the request
            data = {
                "definition": json.dumps(
                    {
                        "locales": locales.split(","),
                        "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                    }
                    if locales
                    else {}
                )
            }
            url = (
                base_url or f"https://{region}.api.cognitive.microsoft.com"
            ) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

            # Use a context manager to ensure the file is properly closed
            with open(file_path, "rb") as audio_file:
                r = requests.post(
                    url=url,
                    files={"audio": audio_file},
                    data=data,
                    headers={
                        "Ocp-Apim-Subscription-Key": api_key,
                    },
                )

            r.raise_for_status()
            response = r.json()

            # Extract the transcript from the response
            if not response.get("combinedPhrases"):
                raise ValueError("No transcription found in response")

            # Get the full transcript from combinedPhrases
            transcript = response["combinedPhrases"][0].get("text", "").strip()
            if not transcript:
                raise ValueError("Empty transcript in response")

            data = {"text": transcript}

            # Save the transcript to a json file (consistent with other providers)
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data

        except (KeyError, IndexError, ValueError) as e:
            log.exception("Error parsing Azure response")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to parse Azure response: {str(e)}",
            )
        except requests.exceptions.RequestException as e:
            log.exception(e)
            detail = None

            try:
                if r is not None and r.status_code != 200:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
            except Exception:
                detail = f"External: {e}"

            raise HTTPException(
                status_code=getattr(r, "status_code", 500) if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )


def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        converted_path = convert_audio_to_mp3(file_path)
        if converted_path:  # fall back to the original file if conversion failed
            file_path = converted_path

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (a single entry if the file is small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit a task for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            # Gather results as they complete
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }
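

def _demo_transcribe(request):  # illustrative sketch, not part of the upstream module
    # metadata carries the optional language hint through to the STT backend;
    # the file path below is a placeholder.
    result = transcribe(request, "/tmp/sample.mp3", metadata={"language": "en"})
    print(result["text"])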


def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")

        return compressed_path
    else:
        return file_path


def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Split audio into chunks that do not exceed max_bytes.
    Returns a list of chunk file paths. If the audio already fits, returns a
    list containing only the original path.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)
    orig_size = file_size

    approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
    chunks = []
    start = 0
    i = 0

    base, _ = os.path.splitext(file_path)

    while start < duration_ms:
        end = min(start + approx_chunk_ms, duration_ms)
        chunk = audio[start:end]
        chunk_path = f"{base}_chunk_{i}.{format}"
        chunk.export(chunk_path, format=format, bitrate=bitrate)

        # Reduce the chunk duration if the exported file is still too large
        while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
            end = start + ((end - start) // 2)
            chunk = audio[start:end]
            chunk.export(chunk_path, format=format, bitrate=bitrate)

        if os.path.getsize(chunk_path) > max_bytes:
            os.remove(chunk_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunks.append(chunk_path)
        start = end
        i += 1

    return chunks
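

def _demo_chunk_estimate():  # illustrative sketch, not part of the upstream module
    # split_audio() sizes chunks proportionally to the byte overshoot: for a
    # 60s file of 40MB with a 20MB cap, duration_ms * (max_bytes / orig_size)
    # gives 30000ms, minus the 1000ms safety margin -> 29000ms chunks (with a
    # floor of 1000ms).
    duration_ms, orig_size, max_bytes = 60_000, 40 * 1024 * 1024, 20 * 1024 * 1024
    print(max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000))  # -> 29000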


@router.post("/transcriptions")
def transcription(
    request: Request,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    SUPPORTED_CONTENT_TYPES = {"video/webm"}  # Extend if you add more video types!
    if not (
        file.content_type.startswith("audio/")
        or file.content_type in SUPPORTED_CONTENT_TYPES
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,

@ -637,19 +935,18 @@ def transcription(
        f.write(contents)

    try:
        metadata = None

        if language:
            metadata = {"language": language}

        result = transcribe(request, file_path, metadata)

        return {
            **result,
            "filename": os.path.basename(file_path),
        }
    except Exception as e:
        log.exception(e)

@ -670,7 +967,22 @@ def transcription(
def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use the custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
                )
                response.raise_for_status()
                data = response.json()
                available_models = data.get("models", [])
            except Exception as e:
                log.error(f"Error fetching models from custom endpoint: {str(e)}")
                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
        else:
            available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(

@ -701,14 +1013,37 @@ def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        # Use the custom endpoint if not using the official OpenAI API URL
        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
            "https://api.openai.com"
        ):
            try:
                response = requests.get(
                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
                )
                response.raise_for_status()
                data = response.json()
                voices_list = data.get("voices", [])
                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
            except Exception as e:
                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
                available_voices = {
                    "alloy": "alloy",
                    "echo": "echo",
                    "fable": "fable",
                    "onyx": "onyx",
                    "nova": "nova",
                    "shimmer": "shimmer",
                }
        else:
            available_voices = {
                "alloy": "alloy",
                "echo": "echo",
                "fable": "fable",
                "onyx": "onyx",
                "nova": "nova",
                "shimmer": "shimmer",
            }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(

@ -720,7 +1055,10 @@ def get_available_voices(request) -> dict:
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
            url = (
                base_url or f"https://{region}.tts.speech.microsoft.com"
            ) + "/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

@ -19,40 +19,45 @@ from open_webui.models.auths import (
    UserResponse,
)
from open_webui.models.users import Users
from open_webui.models.groups import Groups

from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
    WEBUI_AUTH,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
    WEBUI_AUTH_TRUSTED_NAME_HEADER,
    WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
    WEBUI_AUTH_COOKIE_SAME_SITE,
    WEBUI_AUTH_COOKIE_SECURE,
    WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
    SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response, JSONResponse
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel

from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
    decode_token,
    create_api_key,
    create_token,
    get_admin_user,
    get_verified_user,
    get_current_user,
    get_password_hash,
    get_http_authorization_cred,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions

from typing import Optional, List

from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS

if ENABLE_LDAP.value:
    from ldap3 import Server, Connection, NONE, Tls
    from ldap3.utils.conv import escape_filter_chars

router = APIRouter()


@ -73,31 +78,36 @@ class SessionUserResponse(Token, UserResponse):
async def get_session_user(
    request: Request, response: Response, user=Depends(get_current_user)
):
    auth_header = request.headers.get("Authorization")
    auth_token = get_http_authorization_cred(auth_header)
    token = auth_token.credentials
    data = decode_token(token)

    expires_at = None
    if data:
        expires_at = data.get("exp")

    if (expires_at is not None) and int(time.time()) > expires_at:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Set the cookie token
    response.set_cookie(
        key="token",
        value=token,
        expires=(
            datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
            if expires_at
            else None
        ),
        httponly=True,  # Ensures the cookie is not accessible via JavaScript
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    user_permissions = get_permissions(
        user.id, request.app.state.config.USER_PERMISSIONS

@ -178,6 +188,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
    LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
    LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
    LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
    LDAP_VALIDATE_CERT = (
        CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
    )
    LDAP_CIPHERS = (
        request.app.state.config.LDAP_CIPHERS
        if request.app.state.config.LDAP_CIPHERS

@ -189,14 +202,14 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):

    try:
        tls = Tls(
            validate=LDAP_VALIDATE_CERT,
            version=PROTOCOL_TLS,
            ca_certs_file=LDAP_CA_CERT_FILE,
            ciphers=LDAP_CIPHERS,
        )
    except Exception as e:
        log.error(f"TLS configuration error: {str(e)}")
        raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")

    try:
        server = Server(

@ -211,7 +224,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
            LDAP_APP_DN,
            LDAP_APP_PASSWORD,
            auto_bind="NONE",
            authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
        )
        if not connection_app.bind():
            raise HTTPException(400, detail="Application account bind failed")

@ -226,14 +239,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
            ],
        )

        if not search_success or not connection_app.entries:
            raise HTTPException(400, detail="User not found in the LDAP server")

        entry = connection_app.entries[0]
        username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
        email = entry[
            f"{LDAP_ATTRIBUTE_FOR_MAIL}"
        ].value  # retrieve the Attribute value
        if not email:
            raise HTTPException(400, "User does not have a valid email address.")
        elif isinstance(email, str):
            email = email.lower()
        elif isinstance(email, list):
            email = email[0].lower()
        else:
            email = str(email).lower()

        cn = str(entry["cn"])
        user_dn = entry.entry_dn
|
||||
|
||||
|
|
@ -246,19 +268,24 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
authentication="SIMPLE",
|
||||
)
|
||||
if not connection_user.bind():
|
||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||
raise HTTPException(400, "Authentication failed.")
|
||||
|
||||
user = Users.get_user_by_email(mail)
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
if user_count == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, role=role
|
||||
email=email,
|
||||
password=str(uuid.uuid4()),
|
||||
name=cn,
|
||||
role=role,
|
||||
)
|
||||
|
||||
if not user:
|
||||
|
|
@@ -269,23 +296,38 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             except HTTPException:
                 raise
             except Exception as err:
-                raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
+                log.error(f"LDAP user creation error: {str(err)}")
+                raise HTTPException(
+                    500, detail="Internal error occurred during LDAP user creation."
+                )

-        user = Auths.authenticate_user_by_trusted_header(mail)
+        user = Auths.authenticate_user_by_email(email)

         if user:
+            expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
+            expires_at = None
+            if expires_delta:
+                expires_at = int(time.time()) + int(expires_delta.total_seconds())
+
             token = create_token(
                 data={"id": user.id},
-                expires_delta=parse_duration(
-                    request.app.state.config.JWT_EXPIRES_IN
-                ),
+                expires_delta=expires_delta,
             )

             # Set the cookie token
             response.set_cookie(
                 key="token",
                 value=token,
+                expires=(
+                    datetime.datetime.fromtimestamp(
+                        expires_at, datetime.timezone.utc
+                    )
+                    if expires_at
+                    else None
+                ),
                 httponly=True,  # Ensures the cookie is not accessible via JavaScript
                 samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
                 secure=WEBUI_AUTH_COOKIE_SECURE,
             )

             user_permissions = get_permissions(
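A standalone sketch of the token-expiry arithmetic introduced here, assuming `parse_duration` returns a `datetime.timedelta` (or `None` for non-expiring tokens):

```python
import time
import datetime


def compute_expires_at(expires_delta: datetime.timedelta | None) -> int | None:
    """Epoch seconds at which the session token expires, or None for no expiry."""
    if expires_delta is None:
        return None
    return int(time.time()) + int(expires_delta.total_seconds())


expires_at = compute_expires_at(datetime.timedelta(hours=4))
# The cookie expiry is the same instant expressed as an aware datetime:
cookie_expires = (
    datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
    if expires_at
    else None
)
print(expires_at, cookie_expires)
```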
@@ -295,6 +337,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             return {
                 "token": token,
                 "token_type": "Bearer",
+                "expires_at": expires_at,
                 "id": user.id,
                 "email": user.email,
                 "name": user.name,

@@ -305,12 +348,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             else:
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         else:
-            raise HTTPException(
-                400,
-                f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
-            )
+            raise HTTPException(400, "User record mismatch.")
     except Exception as e:
-        raise HTTPException(400, detail=str(e))
+        log.error(f"LDAP authentication error: {str(e)}")
+        raise HTTPException(400, detail="LDAP authentication failed.")


 ############################
@@ -324,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

-        trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
-        trusted_name = trusted_email
+        email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
+        name = email

         if WEBUI_AUTH_TRUSTED_NAME_HEADER:
-            trusted_name = request.headers.get(
-                WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
-            )
-        if not Users.get_user_by_email(trusted_email.lower()):
+            name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
+
+        if not Users.get_user_by_email(email.lower()):
             await signup(
                 request,
                 response,
-                SignupForm(
-                    email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
-                ),
+                SignupForm(email=email, password=str(uuid.uuid4()), name=name),
             )
-        user = Auths.authenticate_user_by_trusted_header(trusted_email)
+
+        user = Auths.authenticate_user_by_email(email)
+        if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
+            group_names = request.headers.get(
+                WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
+            ).split(",")
+            group_names = [name.strip() for name in group_names if name.strip()]
+
+            if group_names:
+                Groups.sync_user_groups_by_group_names(user.id, group_names)

     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"
         admin_password = "admin"
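The trusted-groups header is parsed as a comma-separated list with empty entries dropped; a quick standalone look at that parsing (the header value is illustrative):

```python
# Hypothetical value of the trusted-groups header.
raw = "engineering, sales,, ops "
group_names = [name.strip() for name in raw.split(",") if name.strip()]
print(group_names)  # ['engineering', 'sales', 'ops']
```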
@@ -413,6 +462,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):

 @router.post("/signup", response_model=SessionUserResponse)
 async def signup(request: Request, response: Response, form_data: SignupForm):

     if WEBUI_AUTH:
         if (
             not request.app.state.config.ENABLE_SIGNUP

@@ -427,6 +477,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
         )

+    user_count = Users.get_num_users()
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -437,14 +488,15 @@ async def signup(request: Request, response: Response, form_data: SignupForm):

     try:
         role = (
-            "admin"
-            if Users.get_num_users() == 0
-            else request.app.state.config.DEFAULT_USER_ROLE
+            "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
         )

-        if Users.get_num_users() == 0:
-            # Disable signup after the first user is created
-            request.app.state.config.ENABLE_SIGNUP = False
+        # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
+        if len(form_data.password.encode("utf-8")) > 72:
+            raise HTTPException(
+                status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
+            )

         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
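Note that bcrypt's 72-byte limit applies to the UTF-8 encoding, not the character count, so multi-byte characters hit it sooner; a quick illustration:

```python
# The bcrypt limit is on bytes, not characters.
ascii_pw = "a" * 72
emoji_pw = "🔑" * 20  # 20 characters, but 4 bytes each in UTF-8

print(len(ascii_pw.encode("utf-8")))  # 72 -> accepted
print(len(emoji_pw.encode("utf-8")))  # 80 -> rejected (> 72 bytes)
```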
@@ -484,6 +536,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):

         if request.app.state.config.WEBHOOK_URL:
             post_webhook(
+                request.app.state.WEBUI_NAME,
                 request.app.state.config.WEBHOOK_URL,
                 WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                 {

@@ -497,6 +550,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             user.id, request.app.state.config.USER_PERMISSIONS
         )

+        if user_count == 0:
+            # Disable signup after the first user is created
+            request.app.state.config.ENABLE_SIGNUP = False
+
         return {
             "token": token,
             "token_type": "Bearer",
@@ -511,7 +568,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
         else:
             raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
     except Exception as err:
-        raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
+        log.error(f"Signup error: {str(err)}")
+        raise HTTPException(500, detail="An internal error occurred during signup.")


 @router.get("/signout")

@@ -529,8 +587,14 @@ async def signout(request: Request, response: Response):
                     logout_url = openid_data.get("end_session_endpoint")
                     if logout_url:
                         response.delete_cookie("oauth_id_token")
-                        return RedirectResponse(
-                            url=f"{logout_url}?id_token_hint={oauth_id_token}"
-                        )
+
+                        return JSONResponse(
+                            status_code=200,
+                            content={
+                                "status": True,
+                                "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
+                            },
+                            headers=response.headers,
+                        )
                 else:
                     raise HTTPException(
@@ -538,9 +602,25 @@ async def signout(request: Request, response: Response):
                         detail="Failed to fetch OpenID configuration",
                     )
         except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
+            log.error(f"OpenID signout error: {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail="Failed to sign out from the OpenID provider.",
+            )

-    return {"status": True}
+    if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
+        return JSONResponse(
+            status_code=200,
+            content={
+                "status": True,
+                "redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
+            },
+            headers=response.headers,
+        )
+
+    return JSONResponse(
+        status_code=200, content={"status": True}, headers=response.headers
+    )


 ############################
@@ -582,7 +662,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         else:
             raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
     except Exception as err:
-        raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
+        log.error(f"Add user error: {str(err)}")
+        raise HTTPException(
+            500, detail="An internal error occurred while adding the user."
+        )


 ############################

@@ -596,7 +679,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_name = None

-        print(admin_email, admin_name)
+        log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")

         if admin_email:
             admin = Users.get_user_by_email(admin_email)
@@ -630,11 +713,16 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
         "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
-        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
+        "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
+        "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
+        "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
+        "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
+        "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
     }

@@ -645,11 +733,16 @@ class AdminConfig(BaseModel):
     ENABLE_API_KEY: bool
     ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
     API_KEY_ALLOWED_ENDPOINTS: str
-    ENABLE_CHANNELS: bool
     DEFAULT_USER_ROLE: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
     ENABLE_MESSAGE_RATING: bool
+    ENABLE_CHANNELS: bool
+    ENABLE_NOTES: bool
+    ENABLE_USER_WEBHOOKS: bool
+    PENDING_USER_OVERLAY_TITLE: Optional[str] = None
+    PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
+    RESPONSE_WATERMARK: Optional[str] = None


 @router.post("/admin/config")
@@ -669,6 +762,7 @@ async def update_admin_config(
     )

     request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
+    request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES

     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE

@@ -684,6 +778,17 @@ async def update_admin_config(
     )
     request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING

+    request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
+
+    request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
+        form_data.PENDING_USER_OVERLAY_TITLE
+    )
+    request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
+        form_data.PENDING_USER_OVERLAY_CONTENT
+    )
+
+    request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
+
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "WEBUI_URL": request.app.state.config.WEBUI_URL,

@@ -691,11 +796,16 @@ async def update_admin_config(
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
         "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
-        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
+        "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
+        "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
+        "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
+        "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
+        "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
     }
@@ -711,6 +821,7 @@ class LdapServerConfig(BaseModel):
     search_filters: str = ""
     use_tls: bool = True
     certificate_path: Optional[str] = None
+    validate_cert: bool = True
     ciphers: Optional[str] = "ALL"

@@ -728,6 +839,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
         "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
         "use_tls": request.app.state.config.LDAP_USE_TLS,
         "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
+        "validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
         "ciphers": request.app.state.config.LDAP_CIPHERS,
     }

@@ -750,11 +862,6 @@ async def update_ldap_server(
         if not value:
             raise HTTPException(400, detail=f"Required field {key} is empty")

-    if form_data.use_tls and not form_data.certificate_path:
-        raise HTTPException(
-            400, detail="TLS is enabled but certificate file path is missing"
-        )
-
     request.app.state.config.LDAP_SERVER_LABEL = form_data.label
     request.app.state.config.LDAP_SERVER_HOST = form_data.host
     request.app.state.config.LDAP_SERVER_PORT = form_data.port

@@ -768,6 +875,7 @@ async def update_ldap_server(
     request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
     request.app.state.config.LDAP_USE_TLS = form_data.use_tls
     request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
+    request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
     request.app.state.config.LDAP_CIPHERS = form_data.ciphers

     return {

@@ -782,6 +890,7 @@ async def update_ldap_server(
         "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
         "use_tls": request.app.state.config.LDAP_USE_TLS,
         "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
+        "validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
         "ciphers": request.app.state.config.LDAP_CIPHERS,
     }

@@ -192,7 +192,7 @@ async def get_channel_messages(
 ############################


-async def send_notification(webui_url, channel, message, active_user_ids):
+async def send_notification(name, webui_url, channel, message, active_user_ids):
     users = get_users_with_access("read", channel.access_control)

     for user in users:

@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):

             if webhook_url:
                 post_webhook(
+                    name,
                     webhook_url,
                     f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
                     {

@@ -302,6 +303,7 @@ async def post_new_message(

         background_tasks.add_task(
             send_notification,
+            request.app.state.WEBUI_NAME,
             request.app.state.config.WEBUI_URL,
             channel,
             message,
@@ -2,6 +2,8 @@ import json
 import logging
 from typing import Optional

+from open_webui.socket.main import get_event_emitter
 from open_webui.models.chats import (
     ChatForm,
+    ChatImportForm,
@@ -74,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
 @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
 async def get_user_chat_list_by_user_id(
     user_id: str,
+    page: Optional[int] = None,
+    query: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
     user=Depends(get_admin_user),
-    skip: int = 0,
-    limit: int = 50,
 ):
     if not ENABLE_ADMIN_CHAT_ACCESS:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )

+    if page is None:
+        page = 1
+
+    limit = 60
+    skip = (page - 1) * limit
+
+    filter = {}
+    if query:
+        filter["query"] = query
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
     return Chats.get_chat_list_by_user_id(
-        user_id, include_archived=True, skip=skip, limit=limit
+        user_id, include_archived=True, filter=filter, skip=skip, limit=limit
     )
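The page-to-offset conversion used by both paginated chat lists is simple but worth spelling out; a standalone sketch:

```python
# 1-based page number -> (skip, limit) slice for the DB query.
def page_to_slice(page: int | None, limit: int = 60) -> tuple[int, int]:
    page = page or 1
    return (page - 1) * limit, limit


print(page_to_slice(None))  # (0, 60)   -> first page
print(page_to_slice(3))     # (120, 60) -> rows 120..179
```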
@@ -192,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
 ############################


-@router.get("/pinned", response_model=list[ChatResponse])
+@router.get("/pinned", response_model=list[ChatTitleIdResponse])
 async def get_user_pinned_chats(user=Depends(get_verified_user)):
     return [
-        ChatResponse(**chat.model_dump())
+        ChatTitleIdResponse(**chat.model_dump())
         for chat in Chats.get_pinned_chats_by_user_id(user.id)
     ]
@@ -265,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):

 @router.get("/archived", response_model=list[ChatTitleIdResponse])
 async def get_archived_session_user_chat_list(
-    user=Depends(get_verified_user), skip: int = 0, limit: int = 50
+    page: Optional[int] = None,
+    query: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    user=Depends(get_verified_user),
 ):
-    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
+    if page is None:
+        page = 1
+
+    limit = 60
+    skip = (page - 1) * limit
+
+    filter = {}
+    if query:
+        filter["query"] = query
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
+    chat_list = [
+        ChatTitleIdResponse(**chat.model_dump())
+        for chat in Chats.get_archived_chat_list_by_user_id(
+            user.id,
+            filter=filter,
+            skip=skip,
+            limit=limit,
+        )
+    ]
+
+    return chat_list


 ############################
@@ -372,6 +419,107 @@ async def update_chat_by_id(
     )


+############################
+# UpdateChatMessageById
+############################
+
+
+class MessageForm(BaseModel):
+    content: str
+
+
+@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
+async def update_chat_message_by_id(
+    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
+):
+    chat = Chats.get_chat_by_id(id)
+
+    if not chat:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if chat.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    chat = Chats.upsert_message_to_chat_by_id_and_message_id(
+        id,
+        message_id,
+        {
+            "content": form_data.content,
+        },
+    )
+
+    event_emitter = get_event_emitter(
+        {
+            "user_id": user.id,
+            "chat_id": id,
+            "message_id": message_id,
+        },
+        False,
+    )
+
+    if event_emitter:
+        await event_emitter(
+            {
+                "type": "chat:message",
+                "data": {
+                    "chat_id": id,
+                    "message_id": message_id,
+                    "content": form_data.content,
+                },
+            }
+        )
+
+    return ChatResponse(**chat.model_dump())
+
+
+############################
+# SendChatMessageEventById
+############################
+
+
+class EventForm(BaseModel):
+    type: str
+    data: dict
+
+
+@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
+async def send_chat_message_event_by_id(
+    id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
+):
+    chat = Chats.get_chat_by_id(id)
+
+    if not chat:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if chat.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    event_emitter = get_event_emitter(
+        {
+            "user_id": user.id,
+            "chat_id": id,
+            "message_id": message_id,
+        }
+    )
+
+    try:
+        if event_emitter:
+            await event_emitter(form_data.model_dump())
+        else:
+            return False
+        return True
+    except:
+        return False
+
+
 ############################
 # DeleteChatById
 ############################
@@ -476,7 +624,12 @@ async def clone_chat_by_id(

 @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
 async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
-    chat = Chats.get_chat_by_share_id(id)
+
+    if user.role == "admin":
+        chat = Chats.get_chat_by_id(id)
+    else:
+        chat = Chats.get_chat_by_share_id(id)
+
     if chat:
         updated_chat = {
             **chat.chat,

@@ -530,8 +683,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):


 @router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
+async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
+    if not has_permission(
+        user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)

     if chat:
         if chat.share_id:
             shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)

@@ -1,5 +1,5 @@
-from fastapi import APIRouter, Depends, Request
-from pydantic import BaseModel
+from fastapi import APIRouter, Depends, Request, HTTPException
+from pydantic import BaseModel, ConfigDict

 from typing import Optional

@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import get_config, save_config
 from open_webui.config import BannerModel

+from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
+

 router = APIRouter()
@@ -66,10 +68,86 @@ async def set_direct_connections_config(
     }


+############################
+# ToolServers Config
+############################
+
+
+class ToolServerConnection(BaseModel):
+    url: str
+    path: str
+    auth_type: Optional[str]
+    key: Optional[str]
+    config: Optional[dict]
+
+    model_config = ConfigDict(extra="allow")
+
+
+class ToolServersConfigForm(BaseModel):
+    TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
+
+
+@router.get("/tool_servers", response_model=ToolServersConfigForm)
+async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
+    }
+
+
+@router.post("/tool_servers", response_model=ToolServersConfigForm)
+async def set_tool_servers_config(
+    request: Request,
+    form_data: ToolServersConfigForm,
+    user=Depends(get_admin_user),
+):
+    request.app.state.config.TOOL_SERVER_CONNECTIONS = [
+        connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
+    ]
+
+    request.app.state.TOOL_SERVERS = await get_tool_servers_data(
+        request.app.state.config.TOOL_SERVER_CONNECTIONS
+    )
+
+    return {
+        "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
+    }
+
+
+@router.post("/tool_servers/verify")
+async def verify_tool_servers_config(
+    request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
+):
+    """
+    Verify the connection to the tool server.
+    """
+    try:
+        token = None
+        if form_data.auth_type == "bearer":
+            token = form_data.key
+        elif form_data.auth_type == "session":
+            token = request.state.token.credentials
+
+        url = f"{form_data.url}/{form_data.path}"
+        return await get_tool_server_data(token, url)
+    except Exception as e:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Failed to connect to the tool server: {str(e)}",
+        )
+
+
 ############################
 # CodeInterpreterConfig
 ############################


 class CodeInterpreterConfigForm(BaseModel):
+    ENABLE_CODE_EXECUTION: bool
+    CODE_EXECUTION_ENGINE: str
+    CODE_EXECUTION_JUPYTER_URL: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
+    CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
     ENABLE_CODE_INTERPRETER: bool
     CODE_INTERPRETER_ENGINE: str
     CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
@@ -77,11 +155,19 @@ class CodeInterpreterConfigForm(BaseModel):
     CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
     CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
     CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
+    CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]


-@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
-async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
+async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
     return {
+        "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
+        "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
+        "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+        "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
+        "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
+        "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
+        "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
         "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
         "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
         "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -89,13 +175,34 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
         "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
         "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
         "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+        "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
     }


-@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
-async def set_code_interpreter_config(
+@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
+async def set_code_execution_config(
     request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
 ):
+    request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
+
+    request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
+    request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
+        form_data.CODE_EXECUTION_JUPYTER_URL
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
+        form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
+    )
+
     request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
     request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
     request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
@@ -116,8 +223,18 @@ async def set_code_interpreter_config(
     request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
         form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
     )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
+        form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
+    )

     return {
+        "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
+        "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
+        "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+        "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
+        "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
+        "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
+        "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
         "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
         "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
         "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,

@@ -125,6 +242,7 @@ async def set_code_interpreter_config(
         "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
         "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
         "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+        "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
     }

@@ -56,19 +56,35 @@ async def update_config(
     }


+class UserResponse(BaseModel):
+    id: str
+    name: str
+    email: str
+    role: str = "pending"
+
+    last_active_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
 class FeedbackUserResponse(FeedbackResponse):
-    user: Optional[UserModel] = None
+    user: Optional[UserResponse] = None


 @router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
 async def get_all_feedbacks(user=Depends(get_admin_user)):
     feedbacks = Feedbacks.get_all_feedbacks()
-    return [
-        FeedbackUserResponse(
-            **feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
-        )
-        for feedback in feedbacks
-    ]
+
+    feedback_list = []
+    for feedback in feedbacks:
+        user = Users.get_user_by_id(feedback.user_id)
+        feedback_list.append(
+            FeedbackUserResponse(
+                **feedback.model_dump(),
+                user=UserResponse(**user.model_dump()) if user else None,
+            )
+        )
+    return feedback_list


 @router.delete("/feedbacks/all")

@@ -80,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
 @router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
 async def get_all_feedbacks(user=Depends(get_admin_user)):
     feedbacks = Feedbacks.get_all_feedbacks()
-    return [
-        FeedbackModel(
-            **feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
-        )
-        for feedback in feedbacks
-    ]
+    return feedbacks


 @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
@@ -1,21 +1,39 @@
 import logging
 import os
 import uuid
+import json
+from fnmatch import fnmatch
 from pathlib import Path
 from typing import Optional
 from urllib.parse import quote

-from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    Query,
+)
 from fastapi.responses import FileResponse, StreamingResponse
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS

+from open_webui.models.users import Users
 from open_webui.models.files import (
     FileForm,
     FileModel,
     FileModelResponse,
     Files,
 )
+from open_webui.models.knowledge import Knowledges

+from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
 from open_webui.routers.retrieval import ProcessFileForm, process_file
+from open_webui.routers.audio import transcribe
 from open_webui.storage.provider import Storage
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from pydantic import BaseModel
@@ -26,6 +44,39 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])

 router = APIRouter()


+############################
+# Check if the current user has access to a file through any knowledge bases the user may be in.
+############################
+
+
+def has_access_to_file(
+    file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
+) -> bool:
+    file = Files.get_file_by_id(file_id)
+    log.debug(f"Checking if user has {access_type} access to file")
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    has_access = False
+    knowledge_base_id = file.meta.get("collection_name") if file.meta else None
+
+    if knowledge_base_id:
+        knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
+            user.id, access_type
+        )
+        for knowledge_base in knowledge_bases:
+            if knowledge_base.id == knowledge_base_id:
+                has_access = True
+                break
+
+    return has_access
+
+
 ############################
 # Upload File
 ############################
@@ -35,19 +86,55 @@ router = APIRouter()
 def upload_file(
     request: Request,
     file: UploadFile = File(...),
+    metadata: Optional[dict | str] = Form(None),
+    process: bool = Query(True),
+    internal: bool = False,
     user=Depends(get_verified_user),
-    file_metadata: dict = {},
 ):
     log.info(f"file.content_type: {file.content_type}")
+
+    if isinstance(metadata, str):
+        try:
+            metadata = json.loads(metadata)
+        except json.JSONDecodeError:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
+            )
+
+    file_metadata = metadata if metadata else {}
+
     try:
         unsanitized_filename = file.filename
         filename = os.path.basename(unsanitized_filename)

+        file_extension = os.path.splitext(filename)[1]
+        # Remove the leading dot from the file extension
+        file_extension = file_extension[1:] if file_extension else ""
+
+        if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+            request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
+                ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
+            ]
+
+            if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(
+                        f"File type {file_extension} is not allowed"
+                    ),
+                )
+
         # replace filename with uuid
         id = str(uuid.uuid4())
         name = filename
         filename = f"{id}_{filename}"
-        contents, file_path = Storage.upload_file(file.file, filename)
+        tags = {
+            "OpenWebUI-User-Email": user.email,
+            "OpenWebUI-User-Id": user.id,
+            "OpenWebUI-User-Name": user.name,
+            "OpenWebUI-File-Id": id,
+        }
+        contents, file_path = Storage.upload_file(file.file, filename, tags)

         file_item = Files.insert_new_file(
             user.id,
@@ -65,19 +152,40 @@ def upload_file(
                 }
             ),
         )
-        try:
-            process_file(request, ProcessFileForm(file_id=id), user=user)
-            file_item = Files.get_file_by_id(id=id)
-        except Exception as e:
-            log.exception(e)
-            log.error(f"Error processing file: {file_item.id}")
-            file_item = FileModelResponse(
-                **{
-                    **file_item.model_dump(),
-                    "error": str(e.detail) if hasattr(e, "detail") else str(e),
-                }
-            )
+        if process:
+            try:
+                if file.content_type:
+                    if file.content_type.startswith("audio/") or file.content_type in {
+                        "video/webm"
+                    }:
+                        file_path = Storage.get_file(file_path)
+                        result = transcribe(request, file_path, file_metadata)
+
+                        process_file(
+                            request,
+                            ProcessFileForm(file_id=id, content=result.get("text", "")),
+                            user=user,
+                        )
+                    elif (not file.content_type.startswith(("image/", "video/"))) or (
+                        request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
+                    ):
+                        process_file(request, ProcessFileForm(file_id=id), user=user)
+                else:
+                    log.info(
+                        f"File type {file.content_type} is not provided, but trying to process anyway"
+                    )
+                    process_file(request, ProcessFileForm(file_id=id), user=user)
+
+                file_item = Files.get_file_by_id(id=id)
+            except Exception as e:
+                log.exception(e)
+                log.error(f"Error processing file: {file_item.id}")
+                file_item = FileModelResponse(
+                    **{
+                        **file_item.model_dump(),
+                        "error": str(e.detail) if hasattr(e, "detail") else str(e),
+                    }
+                )

         if file_item:
             return file_item

@@ -91,7 +199,7 @@ def upload_file(
         log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
+            detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
         )

@@ -101,14 +209,62 @@ def upload_file(


 @router.get("/", response_model=list[FileModelResponse])
-async def list_files(user=Depends(get_verified_user)):
+async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
     if user.role == "admin":
         files = Files.get_files()
     else:
         files = Files.get_files_by_user_id(user.id)

+    if not content:
+        for file in files:
+            if "content" in file.data:
+                del file.data["content"]
+
     return files


+############################
+# Search Files
+############################
+
+
+@router.get("/search", response_model=list[FileModelResponse])
+async def search_files(
+    filename: str = Query(
+        ...,
+        description="Filename pattern to search for. Supports wildcards such as '*.txt'",
+    ),
+    content: bool = Query(True),
+    user=Depends(get_verified_user),
+):
+    """
+    Search for files by filename with support for wildcard patterns.
+    """
+    # Get files according to user role
+    if user.role == "admin":
+        files = Files.get_files()
+    else:
+        files = Files.get_files_by_user_id(user.id)
+
+    # Get matching files
+    matching_files = [
+        file for file in files if fnmatch(file.filename.lower(), filename.lower())
+    ]
+
+    if not matching_files:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="No files found matching the pattern.",
+        )
+
+    if not content:
+        for file in matching_files:
+            if "content" in file.data:
+                del file.data["content"]
+
+    return matching_files
+
+
 ############################
 # Delete All Files
 ############################
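`fnmatch` gives shell-style wildcard matching, and lowering both sides makes it case-insensitive, as in the `/search` endpoint above; a quick standalone demonstration with made-up filenames:

```python
from fnmatch import fnmatch

names = ["Report.PDF", "notes.txt", "data_2024.csv", "SUMMARY.TXT"]
pattern = "*.txt"

print([n for n in names if fnmatch(n.lower(), pattern.lower())])
# ['notes.txt', 'SUMMARY.TXT']
print(fnmatch("report.pdf", "report.*"))  # True
```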
@@ -144,7 +300,17 @@ async def delete_all_files(user=Depends(get_admin_user)):
 async def get_file_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)

-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         return file
     else:
         raise HTTPException(

@@ -162,7 +328,17 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
 async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)

-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         return {"content": file.data.get("content", "")}
     else:
         raise HTTPException(

@@ -186,7 +362,17 @@ async def update_file_data_content_by_id(
 ):
     file = Files.get_file_by_id(id)

-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "write", user)
+    ):
         try:
             process_file(
                 request,
@@ -212,9 +398,22 @@ async def update_file_data_content_by_id(

 @router.get("/{id}/content")
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+async def get_file_content_by_id(
+    id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
+):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)
@@ -225,17 +424,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
-            filename = file.meta.get("name", file.filename)
-            encoded_filename = quote(filename)  # RFC5987 encoding
-
+            content_type = file.meta.get("content_type")
+            filename = file.meta.get("name", file.filename)
+            encoded_filename = quote(filename)
             headers = {}
-            if file.meta.get("content_type") not in [
-                "application/pdf",
-                "text/plain",
-            ]:
-                headers = {
-                    **headers,
-                    "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
-                }

-            return FileResponse(file_path, headers=headers)
+            if attachment:
+                headers["Content-Disposition"] = (
+                    f"attachment; filename*=UTF-8''{encoded_filename}"
+                )
+            else:
+                if content_type == "application/pdf" or filename.lower().endswith(
+                    ".pdf"
+                ):
+                    headers["Content-Disposition"] = (
+                        f"inline; filename*=UTF-8''{encoded_filename}"
+                    )
+                    content_type = "application/pdf"
+                elif content_type != "text/plain":
+                    headers["Content-Disposition"] = (
+                        f"attachment; filename*=UTF-8''{encoded_filename}"
+                    )
+
+            return FileResponse(file_path, headers=headers, media_type=content_type)

         else:
             raise HTTPException(
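The `filename*=UTF-8''...` form is the RFC 5987 encoding used by `Content-Disposition` (RFC 6266) for non-ASCII names; a standalone look at what `quote` produces (the filename is hypothetical):

```python
from urllib.parse import quote

filename = "übersicht 2024.pdf"  # hypothetical non-ASCII filename
encoded = quote(filename)

print(f"attachment; filename*=UTF-8''{encoded}")
# attachment; filename*=UTF-8''%C3%BCbersicht%202024.pdf
```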
@@ -259,14 +470,32 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 @router.get("/{id}/content/html")
 async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    file_user = Users.get_user_by_id(file.user_id)
+    if not file_user.role == "admin":
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)

             # Check if the file already exists in the cache
             if file_path.is_file():
-                print(f"file_path: {file_path}")
+                log.info(f"file_path: {file_path}")
                 return FileResponse(file_path)
             else:
                 raise HTTPException(
@@ -291,7 +520,17 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)

-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         file_path = file.path

         # Handle Unicode filenames

@@ -342,7 +581,18 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 @router.delete("/{id}")
 async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "write", user)
+    ):
         # We should add Chroma cleanup here

         result = Files.delete_file_by_id(id)

@@ -20,11 +20,13 @@ from open_webui.env import SRC_LOG_LEVELS
 from open_webui.constants import ERROR_MESSAGES


-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
 from fastapi.responses import FileResponse, StreamingResponse


 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_permission


 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

@@ -228,7 +230,19 @@ async def update_folder_is_expanded_by_id(

 @router.delete("/{id}")
-async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
+async def delete_folder_by_id(
+    request: Request, id: str, user=Depends(get_verified_user)
+):
+    chat_delete_permission = has_permission(
+        user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
+    )
+
+    if user.role != "admin" and not chat_delete_permission:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     folder = Folders.get_folder_by_id_and_user_id(id, user.id)
     if folder:
         try:

@@ -1,4 +1,8 @@
 import os
+import re
+
 import logging
+import aiohttp
+from pathlib import Path
 from typing import Optional

@@ -8,11 +12,22 @@ from open_webui.models.functions import (
     FunctionResponse,
     Functions,
 )
-from open_webui.utils.plugin import load_function_module_by_id, replace_imports
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    replace_imports,
+    get_function_module_from_cache,
+)
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, HttpUrl
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])


 router = APIRouter()
@@ -36,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
     return Functions.get_functions()


+############################
+# LoadFunctionFromLink
+############################
+
+
+class LoadUrlForm(BaseModel):
+    url: HttpUrl
+
+
+def github_url_to_raw_url(url: str) -> str:
+    # Handle 'tree' (folder) URLs (add main.py at the end)
+    m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
+    if m1:
+        org, repo, branch, path = m1.groups()
+        return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
+
+    # Handle 'blob' (file) URLs
+    m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
+    if m2:
+        org, repo, branch, path = m2.groups()
+        return (
+            f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
+        )
+
+    # No match; return as-is
+    return url
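For clarity, here is what the converter above produces for the two GitHub URL shapes it handles, assuming `github_url_to_raw_url` is in scope (the repository URLs are illustrative, not real):

```python
print(github_url_to_raw_url(
    "https://github.com/acme/tools/blob/main/functions/calc.py"
))
# https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/calc.py

print(github_url_to_raw_url(
    "https://github.com/acme/tools/tree/main/functions/calc"
))
# https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/calc/main.py
```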
+
+
+@router.post("/load/url", response_model=Optional[dict])
+async def load_function_from_url(
+    request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
+):
+    # NOTE: This is NOT a SSRF vulnerability:
+    # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
+    # and does NOT accept untrusted user input. Access is enforced by authentication.
+
+    url = str(form_data.url)
+    if not url:
+        raise HTTPException(status_code=400, detail="Please enter a valid URL")
+
+    url = github_url_to_raw_url(url)
+    url_parts = url.rstrip("/").split("/")
+
+    file_name = url_parts[-1]
+    function_name = (
+        file_name[:-3]
+        if (
+            file_name.endswith(".py")
+            and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
+        )
+        else url_parts[-2] if len(url_parts) > 1 else "function"
+    )
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.get(
+                url, headers={"Content-Type": "application/json"}
+            ) as resp:
+                if resp.status != 200:
+                    raise HTTPException(
+                        status_code=resp.status, detail="Failed to fetch the function"
+                    )
+                data = await resp.text()
+                if not data:
+                    raise HTTPException(
+                        status_code=400, detail="No data received from the URL"
+                    )
+        return {
+            "name": function_name,
+            "content": data,
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
+
+
+############################
+# SyncFunctions
+############################
+
+
+class SyncFunctionsForm(FunctionForm):
+    functions: list[FunctionModel] = []
+
+
+@router.post("/sync", response_model=Optional[FunctionModel])
+async def sync_functions(
+    request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
+):
+    return Functions.sync_functions(user.id, form_data.functions)
+
+
 ############################
 # CreateNewFunction
 ############################
@@ -68,7 +174,7 @@ async def create_new_function(

     function = Functions.insert_new_function(user.id, function_type, form_data)

-    function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+    function_cache_dir = CACHE_DIR / "functions" / form_data.id
     function_cache_dir.mkdir(parents=True, exist_ok=True)

     if function:

@@ -79,7 +185,7 @@ async def create_new_function(
             detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Failed to create a new function: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -183,7 +289,7 @@ async def update_function_by_id(
             FUNCTIONS[id] = function_module

             updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
-            print(updated)
+            log.debug(updated)

             function = Functions.update_function_by_id(id, updated)

@@ -256,11 +362,9 @@ async def get_function_valves_spec_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(
+            request, id
+        )

         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves

@@ -284,11 +388,9 @@ async def update_function_valves_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(
+            request, id
+        )

         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves

@@ -299,7 +401,7 @@ async def update_function_valves_by_id(
             Functions.update_function_valves_by_id(id, valves.model_dump())
             return valves.model_dump()
         except Exception as e:
-            print(e)
+            log.exception(f"Error updating function values by id {id}: {e}")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(e),
@ -347,11 +449,9 @@ async def get_function_user_valves_spec_by_id(
|
|||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
|
|
@ -371,11 +471,9 @@ async def update_function_user_valves_by_id(
|
|||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
|
|
@ -388,7 +486,7 @@ async def update_function_user_valves_by_id(
|
|||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
|
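The five-line "check app state, else load and cache" block that the hunks above delete from every valves endpoint is collapsed into one get_function_module_from_cache call. A minimal, self-contained sketch of that cache-or-load pattern; the loader below is a stand-in, the real one compiles the stored function source:

_CACHE: dict[str, tuple] = {}


def _load_module_by_id(function_id: str) -> tuple:
    # Stand-in for load_function_module_by_id(id); returns
    # (module, function_type, frontmatter).
    return (object(), "filter", {})


def get_module_from_cache(function_id: str) -> tuple:
    # Load at most once, then serve every later call from the cache.
    if function_id not in _CACHE:
        _CACHE[function_id] = _load_module_by_id(function_id)
    return _CACHE[function_id]


module, function_type, frontmatter = get_module_from_cache("my_function_id")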
16  backend/open_webui/routers/groups.py  Normal file → Executable file
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Optional

import logging

from open_webui.models.users import Users
from open_webui.models.groups import (
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status

from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

router = APIRouter()

@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):


@router.post("/create", response_model=Optional[GroupResponse])
-async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
+async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
    try:
        group = Groups.insert_new_group(user.id, form_data)
        if group:
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
                detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
            )
    except Exception as e:
-       print(e)
+       log.exception(f"Error creating a new group: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
@@ -94,7 +100,7 @@ async def update_group_by_id(
            detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
        )
    except Exception as e:
-       print(e)
+       log.exception(f"Error updating group {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
            detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
        )
    except Exception as e:
-       print(e)
+       log.exception(f"Error deleting group {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
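The groups router, like the other routers in this commit, swaps bare print(e) calls for a module-level logger. A minimal sketch of that pattern, with the level map reduced to a plain dict stand-in for SRC_LOG_LEVELS:

import logging

# Stand-in for open_webui.env.SRC_LOG_LEVELS.
SRC_LOG_LEVELS = {"MAIN": logging.INFO}

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

try:
    raise ValueError("demo failure")
except Exception as e:
    # Unlike print(e), this records the traceback and routes through handlers.
    log.exception(f"Error creating a new group: {e}")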
@@ -25,7 +25,7 @@ from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

-IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
+IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)


@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
        },
+       "gemini": {
+           "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+           "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+       },
    }


@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
    COMFYUI_WORKFLOW_NODES: list[dict]


+class GeminiConfigForm(BaseModel):
+   GEMINI_API_BASE_URL: str
+   GEMINI_API_KEY: str
+
+
class ConfigForm(BaseModel):
    enabled: bool
    engine: str
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
+   gemini: GeminiConfigForm


@router.post("/config/update")
@@ -103,6 +113,11 @@ async def update_config(
    )
    request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

+   request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
+       form_data.gemini.GEMINI_API_BASE_URL
+   )
+   request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
+
    request.app.state.config.AUTOMATIC1111_BASE_URL = (
        form_data.automatic1111.AUTOMATIC1111_BASE_URL
    )
@@ -129,6 +144,8 @@ async def update_config(
    request.app.state.config.COMFYUI_BASE_URL = (
        form_data.comfyui.COMFYUI_BASE_URL.strip("/")
    )
+   request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
+
    request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
    request.app.state.config.COMFYUI_WORKFLOW_NODES = (
        form_data.comfyui.COMFYUI_WORKFLOW_NODES
@@ -155,6 +172,10 @@ async def update_config(
            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
        },
+       "gemini": {
+           "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+           "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+       },
    }


@@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
            request.app.state.config.ENABLE_IMAGE_GENERATION = False
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
+
+       headers = None
+       if request.app.state.config.COMFYUI_API_KEY:
+           headers = {
+               "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
+           }
+
        try:
            r = requests.get(
-               url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
+               url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
+               headers=headers,
            )
            r.raise_for_status()
            return True
@@ -224,6 +253,12 @@ def get_image_model(request):
            if request.app.state.config.IMAGE_GENERATION_MODEL
            else "dall-e-2"
        )
+   elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+       return (
+           request.app.state.config.IMAGE_GENERATION_MODEL
+           if request.app.state.config.IMAGE_GENERATION_MODEL
+           else "imagen-3.0-generate-002"
+       )
    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
        return (
            request.app.state.config.IMAGE_GENERATION_MODEL
@@ -298,6 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
        return [
            {"id": "dall-e-2", "name": "DALL·E 2"},
            {"id": "dall-e-3", "name": "DALL·E 3"},
+           {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
        ]
+   elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+       return [
+           {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
+       ]
    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
        # TODO - get models from comfyui
@@ -322,7 +362,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
        if model_node_id:
            model_list_key = None

-           print(workflow[model_node_id]["class_type"])
+           log.info(workflow[model_node_id]["class_type"])
            for key in info[workflow[model_node_id]["class_type"]]["input"][
                "required"
            ]:
@@ -411,7 +451,7 @@ def load_url_image_data(url, headers=None):
        return None


-def upload_image(request, image_metadata, image_data, content_type, user):
+def upload_image(request, image_data, content_type, metadata, user):
    image_format = mimetypes.guess_extension(content_type)
    file = UploadFile(
        file=io.BytesIO(image_data),
@@ -420,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
            "content-type": content_type,
        },
    )
-   file_item = upload_file(request, file, user, file_metadata=image_metadata)
+   file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
    return url


@@ -461,7 +501,11 @@ async def image_generations(
                if form_data.size
                else request.app.state.config.IMAGE_SIZE
            ),
-           "response_format": "b64_json",
+           **(
+               {}
+               if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
+               else {"response_format": "b64_json"}
+           ),
        }

        # Use asyncio.to_thread for the requests.post call
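The payload change above drops response_format when the configured model is gpt-image-1, which returns base64 unconditionally and does not accept that parameter. A tiny standalone sketch of the dict-splat pattern:

def build_payload(model: str, prompt: str) -> dict:
    # The key is spliced in only for the older DALL·E models.
    return {
        "model": model,
        "prompt": prompt,
        **({} if "gpt-image-1" in model else {"response_format": "b64_json"}),
    }


assert "response_format" not in build_payload("gpt-image-1", "a cat")
assert build_payload("dall-e-3", "a cat")["response_format"] == "b64_json"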
@@ -478,11 +522,50 @@ async def image_generations(
        images = []

        for image in res["data"]:
-           image_data, content_type = load_b64_image_data(image["b64_json"])
-           url = upload_image(request, data, image_data, content_type, user)
+           if image_url := image.get("url", None):
+               image_data, content_type = load_url_image_data(image_url, headers)
+           else:
+               image_data, content_type = load_b64_image_data(image["b64_json"])
+
+           url = upload_image(request, image_data, content_type, data, user)
            images.append({"url": url})
        return images

+   elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+       headers = {}
+       headers["Content-Type"] = "application/json"
+       headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
+
+       model = get_image_model(request)
+       data = {
+           "instances": {"prompt": form_data.prompt},
+           "parameters": {
+               "sampleCount": form_data.n,
+               "outputOptions": {"mimeType": "image/png"},
+           },
+       }
+
+       # Use asyncio.to_thread for the requests.post call
+       r = await asyncio.to_thread(
+           requests.post,
+           url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
+           json=data,
+           headers=headers,
+       )
+
+       r.raise_for_status()
+       res = r.json()
+
+       images = []
+       for image in res["predictions"]:
+           image_data, content_type = load_b64_image_data(
+               image["bytesBase64Encoded"]
+           )
+           url = upload_image(request, image_data, content_type, data, user)
+           images.append({"url": url})
+
+       return images
+
    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
        data = {
            "prompt": form_data.prompt,
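The new Gemini branch above calls the Imagen :predict endpoint and decodes base64 PNG bytes from each prediction. A minimal standalone sketch of the same request/response shape; the base URL below is the assumed default and the API key is a placeholder:

import base64
import requests

API_BASE = "https://generativelanguage.googleapis.com/v1beta"  # assumed default
MODEL = "imagen-3.0-generate-002"

resp = requests.post(
    f"{API_BASE}/models/{MODEL}:predict",
    headers={
        "Content-Type": "application/json",
        "x-goog-api-key": "YOUR_API_KEY",  # placeholder
    },
    json={
        "instances": {"prompt": "a watercolor lighthouse"},
        "parameters": {"sampleCount": 1, "outputOptions": {"mimeType": "image/png"}},
    },
)
resp.raise_for_status()

# Each prediction carries the PNG bytes as base64, mirroring the loop above.
for i, pred in enumerate(resp.json()["predictions"]):
    with open(f"out_{i}.png", "wb") as f:
        f.write(base64.b64decode(pred["bytesBase64Encoded"]))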
@@ -529,9 +612,9 @@ async def image_generations(
            image_data, content_type = load_url_image_data(image["url"], headers)
            url = upload_image(
                request,
-               form_data.model_dump(exclude_none=True),
                image_data,
                content_type,
+               form_data.model_dump(exclude_none=True),
                user,
            )
            images.append({"url": url})
@@ -541,7 +624,7 @@ async def image_generations(
        or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
    ):
        if form_data.model:
-           set_image_model(form_data.model)
+           set_image_model(request, form_data.model)

        data = {
            "prompt": form_data.prompt,
@@ -582,9 +665,9 @@ async def image_generations(
            image_data, content_type = load_b64_image_data(image)
            url = upload_image(
                request,
-               {**data, "info": res["info"]},
                image_data,
                content_type,
+               {**data, "info": res["info"]},
                user,
            )
            images.append({"url": url})
@@ -9,8 +9,8 @@ from open_webui.models.knowledge import (
    KnowledgeResponse,
    KnowledgeUserResponse,
)
-from open_webui.models.files import Files, FileModel
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.models.files import Files, FileModel, FileMetadataResponse
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
    process_file,
    ProcessFileForm,
@@ -161,13 +161,94 @@ async def create_new_knowledge(
    )


+############################
+# ReindexKnowledgeFiles
+############################
+
+
+@router.post("/reindex", response_model=bool)
+async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
+   if user.role != "admin":
+       raise HTTPException(
+           status_code=status.HTTP_401_UNAUTHORIZED,
+           detail=ERROR_MESSAGES.UNAUTHORIZED,
+       )
+
+   knowledge_bases = Knowledges.get_knowledge_bases()
+
+   log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
+
+   deleted_knowledge_bases = []
+
+   for knowledge_base in knowledge_bases:
+       # -- Robust error handling for missing or invalid data
+       if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
+           log.warning(
+               f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
+           )
+           try:
+               Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
+               deleted_knowledge_bases.append(knowledge_base.id)
+           except Exception as e:
+               log.error(
+                   f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
+               )
+           continue
+
+       try:
+           file_ids = knowledge_base.data.get("file_ids", [])
+           files = Files.get_files_by_ids(file_ids)
+           try:
+               if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
+                   VECTOR_DB_CLIENT.delete_collection(
+                       collection_name=knowledge_base.id
+                   )
+           except Exception as e:
+               log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
+               continue  # Skip, don't raise
+
+           failed_files = []
+           for file in files:
+               try:
+                   process_file(
+                       request,
+                       ProcessFileForm(
+                           file_id=file.id, collection_name=knowledge_base.id
+                       ),
+                       user=user,
+                   )
+               except Exception as e:
+                   log.error(
+                       f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
+                   )
+                   failed_files.append({"file_id": file.id, "error": str(e)})
+                   continue
+
+       except Exception as e:
+           log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
+           # Don't raise, just continue
+           continue
+
+       if failed_files:
+           log.warning(
+               f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
+           )
+           for failed in failed_files:
+               log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
+
+   log.info(
+       f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
+   )
+   return True
+
+
############################
# GetKnowledgeById
############################


class KnowledgeFilesResponse(KnowledgeResponse):
-   files: list[FileModel]
+   files: list[FileMetadataResponse]


@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
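The reindex endpoint above deliberately collects per-file failures rather than aborting the whole pass on the first error. The same pattern, distilled to a runnable sketch:

def process_all(items, process_one):
    # Keep going on per-item failure; report everything at the end.
    failures = []
    for item in items:
        try:
            process_one(item)
        except Exception as e:
            failures.append({"item": item, "error": str(e)})
            continue
    return failures


def flaky(x):
    if x % 3 == 0:
        raise ValueError(f"bad item {x}")


print(process_all(range(5), flaky))  # failures recorded for items 0 and 3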
@@ -183,7 +264,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
    ):

        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-       files = Files.get_files_by_ids(file_ids)
+       files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -311,7 +392,7 @@ def add_file_to_knowledge_by_id(
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    if knowledge:
-       files = Files.get_files_by_ids(file_ids)
+       files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -388,7 +469,7 @@ def update_file_from_knowledge_by_id(
    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

-   files = Files.get_files_by_ids(file_ids)
+   files = Files.get_file_metadatas_by_ids(file_ids)

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
@@ -437,14 +518,24 @@ def remove_file_from_knowledge_by_id(
        )

    # Remove content from the vector database
-   VECTOR_DB_CLIENT.delete(
-       collection_name=knowledge.id, filter={"file_id": form_data.file_id}
-   )
+   try:
+       VECTOR_DB_CLIENT.delete(
+           collection_name=knowledge.id, filter={"file_id": form_data.file_id}
+       )
+   except Exception as e:
+       log.debug("This was most likely caused by bypassing embedding processing")
+       log.debug(e)
+       pass

-   # Remove the file's collection from vector database
-   file_collection = f"file-{form_data.file_id}"
-   if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
-       VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
+   try:
+       # Remove the file's collection from vector database
+       file_collection = f"file-{form_data.file_id}"
+       if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
+           VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
+   except Exception as e:
+       log.debug("This was most likely caused by bypassing embedding processing")
+       log.debug(e)
+       pass

    # Delete file from database
    Files.delete_file_by_id(form_data.file_id)
@@ -460,7 +551,7 @@ def remove_file_from_knowledge_by_id(
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    if knowledge:
-       files = Files.get_files_by_ids(file_ids)
+       files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -614,7 +705,7 @@ def add_files_to_knowledge_batch(
    )

    # Get files content
-   print(f"files/batch/add - {len(form_data)} files")
+   log.info(f"files/batch/add - {len(form_data)} files")
    files: List[FileModel] = []
    for form in form_data:
        file = Files.get_file_by_id(form.file_id)
@@ -656,7 +747,7 @@ def add_files_to_knowledge_batch(
        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
-           files=Files.get_files_by_ids(existing_file_ids),
+           files=Files.get_file_metadatas_by_ids(existing_file_ids),
            warnings={
                "message": "Some files failed to process",
                "errors": error_details,
@@ -664,5 +755,6 @@ def add_files_to_knowledge_batch(
    )

    return KnowledgeFilesResponse(
-       **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
+       **knowledge.model_dump(),
+       files=Files.get_file_metadatas_by_ids(existing_file_ids),
    )
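The knowledge endpoints above switch from returning full FileModel rows to metadata-only FileMetadataResponse objects. A hedged sketch of why that matters; the field names below are illustrative, not the project's actual schema:

from pydantic import BaseModel


class FileMetadata(BaseModel):
    # Lightweight projection: enough for a file listing in the UI.
    id: str
    filename: str
    size: int


class FullFile(FileMetadata):
    # The full row additionally drags the extracted text along, which can
    # be megabytes per file when listing a large knowledge base.
    content: str


full = FullFile(id="f1", filename="report.pdf", size=1024, content="x" * 10_000)
listing = FileMetadata(**full.model_dump())  # extras are ignored by default
assert "content" not in listing.model_dump()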
@@ -4,7 +4,7 @@ import logging
from typing import Optional

from open_webui.models.memories import Memories, MemoryModel
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS

@@ -57,7 +57,9 @@ async def add_memory(
            {
                "id": memory.id,
                "text": memory.content,
-               "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
+               "vector": request.app.state.EMBEDDING_FUNCTION(
+                   memory.content, user=user
+               ),
                "metadata": {"created_at": memory.created_at},
            }
        ],
@@ -82,7 +84,7 @@ async def query_memory(
):
    results = VECTOR_DB_CLIENT.search(
        collection_name=f"user-memory-{user.id}",
-       vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
+       vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
        limit=form_data.k,
    )

@@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
            {
                "id": memory.id,
                "text": memory.content,
-               "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
+               "vector": request.app.state.EMBEDDING_FUNCTION(
+                   memory.content, user=user
+               ),
                "metadata": {
                    "created_at": memory.created_at,
                    "updated_at": memory.updated_at,
@@ -149,7 +153,9 @@ async def update_memory_by_id(
    form_data: MemoryUpdateModel,
    user=Depends(get_verified_user),
):
-   memory = Memories.update_memory_by_id(memory_id, form_data.content)
+   memory = Memories.update_memory_by_id_and_user_id(
+       memory_id, user.id, form_data.content
+   )
    if memory is None:
        raise HTTPException(status_code=404, detail="Memory not found")

@@ -161,7 +167,7 @@ async def update_memory_by_id(
            "id": memory.id,
            "text": memory.content,
            "vector": request.app.state.EMBEDDING_FUNCTION(
-               memory.content, user
+               memory.content, user=user
            ),
            "metadata": {
                "created_at": memory.created_at,
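The memories router now passes user by keyword everywhere. A small sketch of why that matters once a callable grows optional middle parameters; the signature below is assumed purely for illustration:

def embedding_function(text, prefix=None, user=None):
    # Hypothetical signature: if `prefix` is ever inserted before `user`,
    # positional callers silently pass the user object as the prefix.
    return {"text": text, "prefix": prefix, "user": user}


user = "alice"
wrong = embedding_function("hello", user)        # user lands in `prefix`
right = embedding_function("hello", user=user)   # unambiguous

assert wrong["prefix"] == "alice" and wrong["user"] is None
assert right["user"] == "alice"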
218  backend/open_webui/routers/notes.py  Normal file
@@ -0,0 +1,218 @@
import json
import logging
from typing import Optional


from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel

from open_webui.models.users import Users, UserResponse
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse

from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS


from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()

############################
# GetNotes
############################


@router.get("/", response_model=list[NoteUserResponse])
async def get_notes(request: Request, user=Depends(get_verified_user)):

    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    notes = [
        NoteUserResponse(
            **{
                **note.model_dump(),
                "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
            }
        )
        for note in Notes.get_notes_by_user_id(user.id, "write")
    ]

    return notes


@router.get("/list", response_model=list[NoteUserResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)):

    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    notes = [
        NoteUserResponse(
            **{
                **note.model_dump(),
                "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
            }
        )
        for note in Notes.get_notes_by_user_id(user.id, "read")
    ]

    return notes


############################
# CreateNewNote
############################


@router.post("/create", response_model=Optional[NoteModel])
async def create_new_note(
    request: Request, form_data: NoteForm, user=Depends(get_verified_user)
):

    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        note = Notes.insert_new_note(form_data, user.id)
        return note
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# GetNoteById
############################


@router.get("/{id}", response_model=Optional[NoteModel])
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    note = Notes.get_note_by_id(id)
    if not note:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
        and not has_access(user.id, type="read", access_control=note.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    return note


############################
# UpdateNoteById
############################


@router.post("/{id}/update", response_model=Optional[NoteModel])
async def update_note_by_id(
    request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
):
    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    note = Notes.get_note_by_id(id)
    if not note:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
        and not has_access(user.id, type="write", access_control=note.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    try:
        note = Notes.update_note_by_id(id, form_data)
        return note
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# DeleteNoteById
############################


@router.delete("/{id}/delete", response_model=bool)
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    if user.role != "admin" and not has_permission(
        user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    note = Notes.get_note_by_id(id)
    if not note:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
        and not has_access(user.id, type="write", access_control=note.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    try:
        note = Notes.delete_note_by_id(id)
        return True
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
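Every notes endpoint above repeats the same two-layer gate: a feature-flag permission check, then per-note access control. A condensed sketch of the pattern with stub checkers standing in for has_permission and has_access:

def has_permission(user_id, perm, permissions):
    return permissions.get(user_id, {}).get(perm, False)


def has_access(user_id, type, access_control):
    return user_id in access_control.get(type, [])


def can_write_note(user, note, permissions):
    # Layer 1: is the notes feature enabled for this user at all?
    if user["role"] != "admin" and not has_permission(
        user["id"], "features.notes", permissions
    ):
        return False
    # Layer 2: owner, admin, or explicitly granted write access.
    return (
        user["role"] == "admin"
        or user["id"] == note["user_id"]
        or has_access(user["id"], type="write", access_control=note["access_control"])
    )


note = {"user_id": "u1", "access_control": {"write": ["u2"]}}
perms = {"u2": {"features.notes": True}}
assert can_write_note({"id": "u2", "role": "user"}, note, perms)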
@@ -9,11 +9,18 @@ import os
import random
import re
import time
+from datetime import datetime

from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
+from open_webui.models.users import UserModel
+
+from open_webui.env import (
+   ENABLE_FORWARD_USER_INFO_HEADERS,
+)

from fastapi import (
    Depends,
@@ -26,7 +33,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask


@@ -49,8 +56,9 @@ from open_webui.config import (
from open_webui.env import (
    ENV,
    SRC_LOG_LEVELS,
+   AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
-   AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
+   AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES
@@ -66,12 +74,27 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################


-async def send_get_request(url, key=None):
-   timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+async def send_get_request(url, key=None, user: UserModel = None):
+   timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
-               url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+               url,
+               headers={
+                   "Content-Type": "application/json",
+                   **({"Authorization": f"Bearer {key}"} if key else {}),
+                   **(
+                       {
+                           "X-OpenWebUI-User-Name": user.name,
+                           "X-OpenWebUI-User-Id": user.id,
+                           "X-OpenWebUI-User-Email": user.email,
+                           "X-OpenWebUI-User-Role": user.role,
+                       }
+                       if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                       else {}
+                   ),
+               },
+               ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                return await response.json()
    except Exception as e:
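The header dict above is assembled with conditional dict-splats so optional sections merge in only when their condition holds. The same technique in a tiny standalone form:

def build_headers(key=None, user=None, forward_user=True):
    return {
        "Content-Type": "application/json",
        **({"Authorization": f"Bearer {key}"} if key else {}),
        **(
            {"X-OpenWebUI-User-Id": user["id"], "X-OpenWebUI-User-Role": user["role"]}
            if forward_user and user
            else {}
        ),
    }


print(build_headers())  # only Content-Type
print(build_headers(key="secret", user={"id": "u1", "role": "admin"}))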
@@ -96,6 +119,7 @@ async def send_post_request(
    stream: bool = True,
    key: Optional[str] = None,
    content_type: Optional[str] = None,
+   user: UserModel = None,
):

    r = None
@@ -110,7 +134,18 @@ async def send_post_request(
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
+           ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )
        r.raise_for_status()

@@ -186,12 +221,26 @@ async def verify_connection(
    key = form_data.key

    async with aiohttp.ClientSession(
-       timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+       trust_env=True,
+       timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
                f"{url}/api/version",
-               headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+               headers={
+                   **({"Authorization": f"Bearer {key}"} if key else {}),
+                   **(
+                       {
+                           "X-OpenWebUI-User-Name": user.name,
+                           "X-OpenWebUI-User-Id": user.id,
+                           "X-OpenWebUI-User-Email": user.email,
+                           "X-OpenWebUI-User-Role": user.role,
+                       }
+                       if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                       else {}
+                   ),
+               },
+               ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    detail = f"HTTP Error: {r.status}"
@@ -253,8 +302,24 @@ async def update_config(
    }


-@cached(ttl=3)
-async def get_all_models(request: Request):
+def merge_ollama_models_lists(model_lists):
+   merged_models = {}
+
+   for idx, model_list in enumerate(model_lists):
+       if model_list is not None:
+           for model in model_list:
+               id = model["model"]
+               if id not in merged_models:
+                   model["urls"] = [idx]
+                   merged_models[id] = model
+               else:
+                   merged_models[id]["urls"].append(idx)
+
+   return list(merged_models.values())
+
+
+@cached(ttl=1)
+async def get_all_models(request: Request, user: UserModel = None):
    log.info("get_all_models()")
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
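merge_ollama_models_lists, hoisted above to module level (it was previously a closure inside get_all_models), dedupes the same model reported by several Ollama nodes and records which node indices serve it. A quick driver for the function as it appears in the diff:

def merge_ollama_models_lists(model_lists):
    merged_models = {}
    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                id = model["model"]
                if id not in merged_models:
                    model["urls"] = [idx]
                    merged_models[id] = model
                else:
                    merged_models[id]["urls"].append(idx)
    return list(merged_models.values())


node_0 = [{"model": "llama3:8b"}, {"model": "mistral:7b"}]
node_1 = [{"model": "llama3:8b"}]  # same model also served by node 1
merged = merge_ollama_models_lists([node_0, node_1])
assert {m["model"]: m["urls"] for m in merged} == {
    "llama3:8b": [0, 1],
    "mistral:7b": [0],
}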
@@ -262,7 +327,7 @@ async def get_all_models(request: Request):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
-               request_tasks.append(send_get_request(f"{url}/api/tags"))
+               request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
@@ -275,7 +340,9 @@ async def get_all_models(request: Request):
                key = api_config.get("key", None)

                if enable:
-                   request_tasks.append(send_get_request(f"{url}/api/tags", key))
+                   request_tasks.append(
+                       send_get_request(f"{url}/api/tags", key, user=user)
+                   )
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

@@ -291,7 +358,10 @@ async def get_all_models(request: Request):
                    ),  # Legacy support
                )

+               connection_type = api_config.get("connection_type", "local")
+
                prefix_id = api_config.get("prefix_id", None)
+               tags = api_config.get("tags", [])
                model_ids = api_config.get("model_ids", [])

                if len(model_ids) != 0 and "models" in response:
@@ -302,27 +372,18 @@ async def get_all_models(request: Request):
                        )
                    )

-               if prefix_id:
-                   for model in response.get("models", []):
+               for model in response.get("models", []):
+                   if prefix_id:
                        model["model"] = f"{prefix_id}.{model['model']}"

-       def merge_models_lists(model_lists):
-           merged_models = {}
+                   if tags:
+                       model["tags"] = tags

-           for idx, model_list in enumerate(model_lists):
-               if model_list is not None:
-                   for model in model_list:
-                       id = model["model"]
-                       if id not in merged_models:
-                           model["urls"] = [idx]
-                           merged_models[id] = model
-                       else:
-                           merged_models[id]["urls"].append(idx)
-
-           return list(merged_models.values())
+                   if connection_type:
+                       model["connection_type"] = connection_type

        models = {
-           "models": merge_models_lists(
+           "models": merge_ollama_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
@@ -330,6 +391,22 @@ async def get_all_models(request: Request):
            )
        }

+       try:
+           loaded_models = await get_ollama_loaded_models(request, user=user)
+           expires_map = {
+               m["name"]: m["expires_at"]
+               for m in loaded_models["models"]
+               if "expires_at" in m
+           }
+
+           for m in models["models"]:
+               if m["name"] in expires_map:
+                   # Parse ISO8601 datetime with offset, get unix timestamp as int
+                   dt = datetime.fromisoformat(expires_map[m["name"]])
+                   m["expires_at"] = int(dt.timestamp())
+       except Exception as e:
+           log.debug(f"Failed to get loaded models: {e}")
+
    else:
        models = {"models": []}
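The expires_at handling above converts Ollama's ISO8601 timestamp (which carries a UTC offset) into an integer unix time. The same conversion in isolation; the timestamp value is hypothetical:

from datetime import datetime

raw = "2025-01-15T10:30:00+02:00"  # shape of expires_at from /api/ps

dt = datetime.fromisoformat(raw)   # offset-aware datetime
print(int(dt.timestamp()))         # seconds since the epoch, offset applied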
@@ -360,7 +437,7 @@ async def get_ollama_tags(
    models = []

    if url_idx is None:
-       models = await get_all_models(request)
+       models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -370,7 +447,19 @@ async def get_ollama_tags(
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
-               headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+               headers={
+                   **({"Authorization": f"Bearer {key}"} if key else {}),
+                   **(
+                       {
+                           "X-OpenWebUI-User-Name": user.name,
+                           "X-OpenWebUI-User-Id": user.id,
+                           "X-OpenWebUI-User-Email": user.email,
+                           "X-OpenWebUI-User-Role": user.role,
+                       }
+                       if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                       else {}
+                   ),
+               },
            )
            r.raise_for_status()

@@ -398,24 +487,95 @@ async def get_ollama_tags(
    return models


+@router.get("/api/ps")
+async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
+   """
+   List models that are currently loaded into Ollama memory, and which node they are loaded on.
+   """
+   if request.app.state.config.ENABLE_OLLAMA_API:
+       request_tasks = []
+       for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+           if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
+               url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
+           ):
+               request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
+           else:
+               api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                   str(idx),
+                   request.app.state.config.OLLAMA_API_CONFIGS.get(
+                       url, {}
+                   ),  # Legacy support
+               )
+
+               enable = api_config.get("enable", True)
+               key = api_config.get("key", None)
+
+               if enable:
+                   request_tasks.append(
+                       send_get_request(f"{url}/api/ps", key, user=user)
+                   )
+               else:
+                   request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+
+       responses = await asyncio.gather(*request_tasks)
+
+       for idx, response in enumerate(responses):
+           if response:
+               url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+               api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                   str(idx),
+                   request.app.state.config.OLLAMA_API_CONFIGS.get(
+                       url, {}
+                   ),  # Legacy support
+               )
+
+               prefix_id = api_config.get("prefix_id", None)
+
+               for model in response.get("models", []):
+                   if prefix_id:
+                       model["model"] = f"{prefix_id}.{model['model']}"
+
+       models = {
+           "models": merge_ollama_models_lists(
+               map(
+                   lambda response: response.get("models", []) if response else None,
+                   responses,
+               )
+           )
+       }
+   else:
+       models = {"models": []}
+
+   return models
+
+
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
    if request.app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # returns lowest version
-           request_tasks = [
-               send_get_request(
-                   f"{url}/api/version",
-                   request.app.state.config.OLLAMA_API_CONFIGS.get(
-                       str(idx),
-                       request.app.state.config.OLLAMA_API_CONFIGS.get(
-                           url, {}
-                       ),  # Legacy support
-                   ).get("key", None),
-               )
-               for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
-           ]
+           request_tasks = []
+
+           for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+               api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                   str(idx),
+                   request.app.state.config.OLLAMA_API_CONFIGS.get(
+                       url, {}
+                   ),  # Legacy support
+               )
+
+               enable = api_config.get("enable", True)
+               key = api_config.get("key", None)
+
+               if enable:
+                   request_tasks.append(
+                       send_get_request(
+                           f"{url}/api/version",
+                           key,
+                       )
+                   )

            responses = await asyncio.gather(*request_tasks)
            responses = list(filter(lambda x: x is not None, responses))

@@ -462,35 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
        return {"version": False}


-@router.get("/api/ps")
-async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
-   """
-   List models that are currently loaded into Ollama memory, and which node they are loaded on.
-   """
-   if request.app.state.config.ENABLE_OLLAMA_API:
-       request_tasks = [
-           send_get_request(
-               f"{url}/api/ps",
-               request.app.state.config.OLLAMA_API_CONFIGS.get(
-                   str(idx),
-                   request.app.state.config.OLLAMA_API_CONFIGS.get(
-                       url, {}
-                   ),  # Legacy support
-               ).get("key", None),
-           )
-           for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
-       ]
-       responses = await asyncio.gather(*request_tasks)
-
-       return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
-   else:
-       return {}
-
-
class ModelNameForm(BaseModel):
    name: str


+@router.post("/api/unload")
+async def unload_model(
+   request: Request,
+   form_data: ModelNameForm,
+   user=Depends(get_admin_user),
+):
+   model_name = form_data.name
+   if not model_name:
+       raise HTTPException(
+           status_code=400, detail="Missing 'name' of model to unload."
+       )
+
+   # Refresh/load models if needed, get mapping from name to URLs
+   await get_all_models(request, user=user)
+   models = request.app.state.OLLAMA_MODELS
+
+   # Canonicalize model name (if not supplied with version)
+   if ":" not in model_name:
+       model_name = f"{model_name}:latest"
+
+   if model_name not in models:
+       raise HTTPException(
+           status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
+       )
+   url_indices = models[model_name]["urls"]
+
+   # Send unload to ALL url_indices
+   results = []
+   errors = []
+   for idx in url_indices:
+       url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+       api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+           str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+       )
+       key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
+
+       prefix_id = api_config.get("prefix_id", None)
+       if prefix_id and model_name.startswith(f"{prefix_id}."):
+           model_name = model_name[len(f"{prefix_id}.") :]
+
+       payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
+
+       try:
+           res = await send_post_request(
+               url=f"{url}/api/generate",
+               payload=json.dumps(payload),
+               stream=False,
+               key=key,
+               user=user,
+           )
+           results.append({"url_idx": idx, "success": True, "response": res})
+       except Exception as e:
+           log.exception(f"Failed to unload model on node {idx}: {e}")
+           errors.append({"url_idx": idx, "success": False, "error": str(e)})
+
+   if len(errors) > 0:
+       raise HTTPException(
+           status_code=500,
+           detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
+       )
+
+   return {"status": True}
+
+
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
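The new /api/unload endpoint relies on Ollama's documented convention that a generate request with keep_alive set to 0 evicts the model from memory. A minimal direct sketch of that call; the base URL is the assumed local default:

import requests

OLLAMA_URL = "http://localhost:11434"  # assumed default node
name = "llama3"

# Canonicalize: Ollama registers models with an explicit tag.
if ":" not in name:
    name = f"{name}:latest"

# An empty generate with keep_alive=0 asks the server to unload the model.
r = requests.post(
    f"{OLLAMA_URL}/api/generate",
    json={"model": name, "keep_alive": 0, "prompt": ""},
)
r.raise_for_status()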
@@ -509,6 +708,7 @@ async def pull_model(
        url=f"{url}/api/pull",
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+       user=user,
    )


@@ -527,7 +727,7 @@ async def push_model(
    user=Depends(get_admin_user),
):
    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
@@ -545,6 +745,7 @@ async def push_model(
        url=f"{url}/api/push",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+       user=user,
    )


@@ -571,6 +772,7 @@ async def create_model(
        url=f"{url}/api/create",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+       user=user,
    )


@@ -588,7 +790,7 @@ async def copy_model(
    user=Depends(get_admin_user),
):
    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.source in models:
@@ -609,6 +811,16 @@ async def copy_model(
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
@@ -643,7 +855,7 @@ async def delete_model(
    user=Depends(get_admin_user),
):
    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
@@ -665,6 +877,16 @@ async def delete_model(
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
        )
        r.raise_for_status()
@@ -693,7 +915,7 @@ async def delete_model(
async def show_model_info(
    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
-   await get_all_models(request)
+   await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    if form_data.name not in models:
@@ -714,6 +936,16 @@ async def show_model_info(
        headers={
            "Content-Type": "application/json",
            **({"Authorization": f"Bearer {key}"} if key else {}),
+           **(
+               {
+                   "X-OpenWebUI-User-Name": user.name,
+                   "X-OpenWebUI-User-Id": user.id,
+                   "X-OpenWebUI-User-Email": user.email,
+                   "X-OpenWebUI-User-Role": user.role,
+               }
+               if ENABLE_FORWARD_USER_INFO_HEADERS and user
+               else {}
+           ),
        },
        data=form_data.model_dump_json(exclude_none=True).encode(),
    )
@@ -757,7 +989,7 @@ async def embed(
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
@@ -774,8 +1006,16 @@ async def embed(
        )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+   api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+       str(url_idx),
+       request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+   )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

+   prefix_id = api_config.get("prefix_id", None)
+   if prefix_id:
+       form_data.model = form_data.model.replace(f"{prefix_id}.", "")
+
    try:
        r = requests.request(
            method="POST",
@@ -783,6 +1023,16 @@ async def embed(
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
@@ -826,7 +1076,7 @@ async def embeddings(
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
@@ -843,8 +1093,16 @@ async def embeddings(
        )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+   api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+       str(url_idx),
+       request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+   )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

+   prefix_id = api_config.get("prefix_id", None)
+   if prefix_id:
+       form_data.model = form_data.model.replace(f"{prefix_id}.", "")
+
    try:
        r = requests.request(
            method="POST",
@@ -852,6 +1110,16 @@ async def embeddings(
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
@@ -882,7 +1150,7 @@ class GenerateCompletionForm(BaseModel):
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
-   format: Optional[str] = None
+   format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
@@ -901,7 +1169,7 @@ async def generate_completion(
    user=Depends(get_verified_user),
):
    if url_idx is None:
-       await get_all_models(request)
+       await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
@@ -931,20 +1199,34 @@ async def generate_completion(
        url=f"{url}/api/generate",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+       user=user,
    )


class ChatMessage(BaseModel):
    role: str
-   content: str
+   content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None

+   @validator("content", pre=True)
+   @classmethod
+   def check_at_least_one_field(cls, field_value, values, **kwargs):
+       # Raise an error if both 'content' and 'tool_calls' are None
+       if field_value is None and (
+           "tool_calls" not in values or values["tool_calls"] is None
+       ):
+           raise ValueError(
+               "At least one of 'content' or 'tool_calls' must be provided"
+           )
+
+       return field_value


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: list[ChatMessage]
-   format: Optional[dict] = None
+   format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
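The ChatMessage form above enforces "content or tool_calls" with a legacy pydantic v1-style validator. A v2-style restatement of the same constraint, written only to show the intended behavior, not the project's code:

from typing import Optional
from pydantic import BaseModel, model_validator


class ChatMessageSketch(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None

    @model_validator(mode="after")
    def check_at_least_one_field(self):
        if self.content is None and self.tool_calls is None:
            raise ValueError(
                "At least one of 'content' or 'tool_calls' must be provided"
            )
        return self


ChatMessageSketch(role="assistant", tool_calls=[{"function": {"name": "f"}}])  # ok
try:
    ChatMessageSketch(role="assistant")
except ValueError as e:
    print("rejected:", e)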
@@ -1001,13 +1283,14 @@ async def generate_chat_completion(
        params = model_info.params.model_dump()

        if params:
            if payload.get("options") is None:
                payload["options"] = {}
            system = params.pop("system", None)

            # Unlike OpenAI, Ollama does not support params directly in the body
            payload["options"] = apply_model_params_to_body_ollama(
                params, payload["options"]
                params, (payload.get("options", {}) or {})
            )
            payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

    # Check if user has access to the model
    if not bypass_filter and user.role == "user":

@@ -1040,13 +1323,14 @@ async def generate_chat_completion(
        prefix_id = api_config.get("prefix_id", None)
        if prefix_id:
            payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    # payload["keep_alive"] = -1 # keep alive forever
    return await send_post_request(
        url=f"{url}/api/chat",
        payload=json.dumps(payload),
        stream=form_data.stream,
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        content_type="application/x-ndjson",
        user=user,
    )


@@ -1058,7 +1342,7 @@ class OpenAIChatMessageContent(BaseModel):

class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[str, list[OpenAIChatMessageContent]]
    content: Union[Optional[str], list[OpenAIChatMessageContent]]

    model_config = ConfigDict(extra="allow")


@@ -1149,6 +1433,7 @@ async def generate_openai_completion(
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


@@ -1187,8 +1472,10 @@ async def generate_openai_chat_completion(
        params = model_info.params.model_dump()

        if params:
            system = params.pop("system", None)

            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

    # Check if user has access to the model
    if user.role == "user":

@@ -1227,6 +1514,7 @@ async def generate_openai_chat_completion(
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


@@ -1240,7 +1528,7 @@ async def get_openai_models(

    models = []
    if url_idx is None:
        model_list = await get_all_models(request)
        model_list = await get_all_models(request, user=user)
        models = [
            {
                "id": model["model"],

@@ -1341,7 +1629,9 @@ async def download_file_stream(
    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
        async with session.get(
            file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
        ) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:

@@ -1356,7 +1646,8 @@ async def download_file_stream(

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    chunk_size = 1024 * 1024 * 2
                    hashed = calculate_sha256(file, chunk_size)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
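Hashing the downloaded blob in fixed-size chunks keeps memory flat for multi-gigabyte model files. A minimal sketch of a chunked calculate_sha256 compatible with the call above (the helper's real implementation lives elsewhere in the codebase; this signature is an assumption for illustration):

import hashlib


def calculate_sha256(file, chunk_size=1024 * 1024 * 2):
    # Stream the file through the hash in 2 MiB chunks instead of
    # reading the whole blob into memory at once.
    sha256 = hashlib.sha256()
    for chunk in iter(lambda: file.read(chunk_size), b""):
        sha256.update(chunk)
    return sha256.hexdigest()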
@@ -1420,7 +1711,9 @@ async def upload_model(
    if url_idx is None:
        url_idx = 0
    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    file_path = os.path.join(UPLOAD_DIR, file.filename)

    filename = os.path.basename(file.filename)
    file_path = os.path.join(UPLOAD_DIR, filename)
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # --- P1: save file locally ---

@@ -1465,13 +1758,13 @@ async def upload_model(
            os.remove(file_path)

            # Create model in ollama
            model_name, ext = os.path.splitext(file.filename)
            model_name, ext = os.path.splitext(filename)
            log.info(f"Created Model: {model_name}")  # DEBUG

            create_payload = {
                "model": model_name,
                # Reference the file by its original name => the uploaded blob's digest
                "files": {file.filename: f"sha256:{file_hash}"},
                "files": {filename: f"sha256:{file_hash}"},
            }
            log.info(f"Model Payload: {create_payload}")  # DEBUG

@@ -1488,7 +1781,7 @@ async def upload_model(
            done_msg = {
                "done": True,
                "blob": f"sha256:{file_hash}",
                "name": file.filename,
                "name": filename,
                "model_created": model_name,
            }
            yield f"data: {json.dumps(done_msg)}\n\n"

@@ -21,11 +21,13 @@ from open_webui.config import (
    CACHE_DIR,
)
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    ENABLE_FORWARD_USER_INFO_HEADERS,
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS

@@ -35,6 +37,9 @@ from open_webui.utils.payload import (
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)
from open_webui.utils.misc import (
    convert_logit_bias_input_to_json,
)

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access

@@ -51,12 +56,26 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
##########################################


async def send_get_request(url, key=None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async def send_get_request(url, key=None, user: UserModel = None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
                url,
                headers={
                    **({"Authorization": f"Bearer {key}"} if key else {}),
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                return await response.json()
    except Exception as e:

@@ -75,18 +94,23 @@ async def cleanup_response(
        await session.close()


def openai_o1_o3_handler(payload):
def openai_o_series_handler(payload):
    """
    Handle o1, o3 specific parameters
    Handle "o" series specific parameters
    """
    if "max_tokens" in payload:
        # Remove "max_tokens" from the payload
        # Convert "max_tokens" to "max_completion_tokens" for all o-series models
        payload["max_completion_tokens"] = payload["max_tokens"]
        del payload["max_tokens"]

    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
    # Handle system role conversion based on model type
    if payload["messages"][0]["role"] == "system":
        payload["messages"][0]["role"] = "user"
        model_lower = payload["model"].lower()
        # Legacy models use "user" role instead of "system"
        if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
            payload["messages"][0]["role"] = "user"
        else:
            payload["messages"][0]["role"] = "developer"

    return payload
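The renamed handler now maps max_tokens to max_completion_tokens for every o-series model and picks the right role for a leading system message. A quick illustration with a made-up payload (the model name is only an example):

payload = {
    "model": "o3-mini",
    "max_tokens": 256,
    "messages": [{"role": "system", "content": "Be terse."}],
}
payload = openai_o_series_handler(payload)
# max_tokens is renamed, and non-legacy o-series models get "developer":
# {"model": "o3-mini", "max_completion_tokens": 256,
#  "messages": [{"role": "developer", "content": "Be terse."}]}
# o1-mini / o1-preview are treated as legacy and fall back to "user".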
@@ -172,7 +196,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

@@ -247,7 +271,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
    raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)


async def get_all_models_responses(request: Request) -> list:
async def get_all_models_responses(request: Request, user: UserModel) -> list:
    if not request.app.state.config.ENABLE_OPENAI_API:
        return []

@@ -271,7 +295,9 @@ async def get_all_models_responses(request: Request) -> list:
            ):
                request_tasks.append(
                    send_get_request(
                        f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
                        f"{url}/models",
                        request.app.state.config.OPENAI_API_KEYS[idx],
                        user=user,
                    )
                )
            else:

@@ -291,6 +317,7 @@ async def get_all_models_responses(request: Request) -> list:
                        send_get_request(
                            f"{url}/models",
                            request.app.state.config.OPENAI_API_KEYS[idx],
                            user=user,
                        )
                    )
                else:

@@ -326,14 +353,22 @@ async def get_all_models_responses(request: Request) -> list:
                ),  # Legacy support
            )

            connection_type = api_config.get("connection_type", "external")
            prefix_id = api_config.get("prefix_id", None)
            tags = api_config.get("tags", [])

            if prefix_id:
                for model in (
                    response if isinstance(response, list) else response.get("data", [])
                ):
            for model in (
                response if isinstance(response, list) else response.get("data", [])
            ):
                if prefix_id:
                    model["id"] = f"{prefix_id}.{model['id']}"

                if tags:
                    model["tags"] = tags

                if connection_type:
                    model["connection_type"] = connection_type

    log.debug(f"get_all_models:responses() {responses}")
    return responses

@@ -351,14 +386,14 @@ async def get_filtered_models(models, user):
    return filtered_models


@cached(ttl=3)
async def get_all_models(request: Request) -> dict[str, list]:
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
    log.info("get_all_models()")

    if not request.app.state.config.ENABLE_OPENAI_API:
        return {"data": []}

    responses = await get_all_models_responses(request)
    responses = await get_all_models_responses(request, user=user)

    def extract_data(response):
        if response and "data" in response:

@@ -373,6 +408,7 @@ async def get_all_models(request: Request) -> dict[str, list]:

    for idx, models in enumerate(model_lists):
        if models is not None and "error" not in models:

            merged_list.extend(
                [
                    {

@@ -380,21 +416,25 @@ async def get_all_models(request: Request) -> dict[str, list]:
                        "name": model.get("name", model["id"]),
                        "owned_by": "openai",
                        "openai": model,
                        "connection_type": model.get("connection_type", "external"),
                        "urlIdx": idx,
                    }
                    for model in models
                    if "api.openai.com"
                    not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
                    or not any(
                        name in model["id"]
                        for name in [
                            "babbage",
                            "dall-e",
                            "davinci",
                            "embedding",
                            "tts",
                            "whisper",
                        ]
                    if (model.get("id") or model.get("name"))
                    and (
                        "api.openai.com"
                        not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
                        or not any(
                            name in model["id"]
                            for name in [
                                "babbage",
                                "dall-e",
                                "davinci",
                                "embedding",
                                "tts",
                                "whisper",
                            ]
                        )
                    )
                ]
            )

@@ -418,65 +458,79 @@ async def get_models(
    }

    if url_idx is None:
        models = await get_all_models(request)
        models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = request.app.state.config.OPENAI_API_KEYS[url_idx]

        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
            str(url_idx),
            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
        )

        r = None
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
            )
            trust_env=True,
            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
        ) as session:
            try:
                async with session.get(
                    f"{url}/models",
                    headers={
                        "Authorization": f"Bearer {key}",
                        "Content-Type": "application/json",
                        **(
                            {
                                "X-OpenWebUI-User-Name": user.name,
                                "X-OpenWebUI-User-Id": user.id,
                                "X-OpenWebUI-User-Email": user.email,
                                "X-OpenWebUI-User-Role": user.role,
                            }
                            if ENABLE_FORWARD_USER_INFO_HEADERS
                            else {}
                        ),
                    },
                ) as r:
                    if r.status != 200:
                        # Extract response error details if available
                        error_detail = f"HTTP Error: {r.status}"
                        res = await r.json()
                        if "error" in res:
                            error_detail = f"External Error: {res['error']}"
                        raise Exception(error_detail)
                headers = {
                    "Content-Type": "application/json",
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS
                        else {}
                    ),
                }

                    response_data = await r.json()
                if api_config.get("azure", False):
                    models = {
                        "data": api_config.get("model_ids", []) or [],
                        "object": "list",
                    }
                else:
                    headers["Authorization"] = f"Bearer {key}"

                    # Check if we're calling OpenAI API based on the URL
                    if "api.openai.com" in url:
                        # Filter models according to the specified conditions
                        response_data["data"] = [
                            model
                            for model in response_data.get("data", [])
                            if not any(
                                name in model["id"]
                                for name in [
                                    "babbage",
                                    "dall-e",
                                    "davinci",
                                    "embedding",
                                    "tts",
                                    "whisper",
                                ]
                            )
                        ]
                    async with session.get(
                        f"{url}/models",
                        headers=headers,
                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
                    ) as r:
                        if r.status != 200:
                            # Extract response error details if available
                            error_detail = f"HTTP Error: {r.status}"
                            res = await r.json()
                            if "error" in res:
                                error_detail = f"External Error: {res['error']}"
                            raise Exception(error_detail)

                    models = response_data
                        response_data = await r.json()

                        # Check if we're calling OpenAI API based on the URL
                        if "api.openai.com" in url:
                            # Filter models according to the specified conditions
                            response_data["data"] = [
                                model
                                for model in response_data.get("data", [])
                                if not any(
                                    name in model["id"]
                                    for name in [
                                        "babbage",
                                        "dall-e",
                                        "davinci",
                                        "embedding",
                                        "tts",
                                        "whisper",
                                    ]
                                )
                            ]

                        models = response_data
            except aiohttp.ClientError as e:
                # ClientError covers all aiohttp requests issues
                log.exception(f"Client error: {str(e)}")

@@ -498,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
    url: str
    key: str

    config: Optional[dict] = None


@router.post("/verify")
async def verify_connection(

@@ -506,27 +562,64 @@ async def verify_connection(
    url = form_data.url
    key = form_data.key

    api_config = form_data.config or {}

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
                f"{url}/models",
                headers={
                    "Authorization": f"Bearer {key}",
                    "Content-Type": "application/json",
                },
            ) as r:
                if r.status != 200:
                    # Extract response error details if available
                    error_detail = f"HTTP Error: {r.status}"
                    res = await r.json()
                    if "error" in res:
                        error_detail = f"External Error: {res['error']}"
                    raise Exception(error_detail)
            headers = {
                "Content-Type": "application/json",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS
                    else {}
                ),
            }

                response_data = await r.json()
                return response_data
            if api_config.get("azure", False):
                headers["api-key"] = key
                api_version = api_config.get("api_version", "") or "2023-03-15-preview"

                async with session.get(
                    url=f"{url}/openai/models?api-version={api_version}",
                    headers=headers,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        # Extract response error details if available
                        error_detail = f"HTTP Error: {r.status}"
                        res = await r.json()
                        if "error" in res:
                            error_detail = f"External Error: {res['error']}"
                        raise Exception(error_detail)

                    response_data = await r.json()
                    return response_data
            else:
                headers["Authorization"] = f"Bearer {key}"

                async with session.get(
                    f"{url}/models",
                    headers=headers,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        # Extract response error details if available
                        error_detail = f"HTTP Error: {r.status}"
                        res = await r.json()
                        if "error" in res:
                            error_detail = f"External Error: {res['error']}"
                        raise Exception(error_detail)

                    response_data = await r.json()
                    return response_data

        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp requests issues

@@ -540,6 +633,63 @@ async def verify_connection(
        raise HTTPException(status_code=500, detail=error_detail)


def convert_to_azure_payload(
    url,
    payload: dict,
):
    model = payload.get("model", "")

    # Filter allowed parameters based on Azure OpenAI API
    allowed_params = {
        "messages",
        "temperature",
        "role",
        "content",
        "contentPart",
        "contentPartImage",
        "enhancements",
        "dataSources",
        "n",
        "stream",
        "stop",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
        "logit_bias",
        "user",
        "function_call",
        "functions",
        "tools",
        "tool_choice",
        "top_p",
        "log_probs",
        "top_logprobs",
        "response_format",
        "seed",
        "max_completion_tokens",
    }

    # Special handling for o-series models
    if model.startswith("o") and model.endswith("-mini"):
        # Convert max_tokens to max_completion_tokens for o-series models
        if "max_tokens" in payload:
            payload["max_completion_tokens"] = payload["max_tokens"]
            del payload["max_tokens"]

        # Remove temperature if not 1 for o-series models
        if "temperature" in payload and payload["temperature"] != 1:
            log.debug(
                f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
            )
            del payload["temperature"]

    # Filter out unsupported parameters
    payload = {k: v for k, v in payload.items() if k in allowed_params}

    url = f"{url}/openai/deployments/{model}"
    return url, payload
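convert_to_azure_payload rewrites the base URL into a deployment-scoped endpoint and strips body parameters Azure rejects; the model name moves into the URL and is filtered out of the body. A small illustration (the resource URL and values are made up):

url, payload = convert_to_azure_payload(
    "https://myres.openai.azure.com",  # hypothetical Azure resource URL
    {
        "model": "gpt-4o",
        "temperature": 0.2,
        "unsupported_key": 1,
        "messages": [{"role": "user", "content": "hi"}],
    },
)
# url     -> "https://myres.openai.azure.com/openai/deployments/gpt-4o"
# payload -> {"temperature": 0.2, "messages": [...]}  (unknown keys dropped)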
@router.post("/chat/completions")
|
||||
async def generate_chat_completion(
|
||||
request: Request,
|
||||
|
|
@ -565,8 +715,12 @@ async def generate_chat_completion(
|
|||
model_id = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
|
|
@ -587,7 +741,7 @@ async def generate_chat_completion(
|
|||
detail="Model not found",
|
||||
)
|
||||
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
|
|
@ -621,10 +775,10 @@ async def generate_chat_completion(
|
|||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
|
||||
if is_o1_o3:
|
||||
payload = openai_o1_o3_handler(payload)
|
||||
# Check if model is from "o" series
|
||||
is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
|
||||
if is_o_series:
|
||||
payload = openai_o_series_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
|
|
@ -635,6 +789,43 @@ async def generate_chat_completion(
|
|||
del payload["max_tokens"]
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
if "logit_bias" in payload:
|
||||
payload["logit_bias"] = json.loads(
|
||||
convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
request_url, payload = convert_to_azure_payload(url, payload)
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = api_version
|
||||
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
||||
else:
|
||||
request_url = f"{url}/chat/completions"
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
payload = json.dumps(payload)
|
||||
|
||||
r = None
|
||||
|
|
@ -649,30 +840,10 @@ async def generate_chat_completion(
|
|||
|
||||
r = await session.request(
|
||||
method="POST",
|
||||
url=f"{url}/chat/completions",
|
||||
url=request_url,
|
||||
data=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
|
||||
# Check if response is SSE
|
||||
|
|
@ -801,31 +972,54 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
idx = 0
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = (
|
||||
api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
)
|
||||
|
||||
payload = json.loads(body)
|
||||
url, payload = convert_to_azure_payload(url, payload)
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
request_url = f"{url}/{path}"
|
||||
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
url=f"{url}/{path}",
|
||||
url=request_url,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
|
|
@ -851,7 +1045,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from fastapi import (
|
|||
status,
|
||||
APIRouter,
|
||||
)
|
||||
import aiohttp
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
|
|
@ -17,7 +18,7 @@ from pydantic import BaseModel
|
|||
from starlette.responses import FileResponse
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
|
@ -56,96 +57,111 @@ def get_sorted_filters(model_id, models):
|
|||
return sorted_filters
|
||||
|
||||
|
||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
res = r.json()
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
res = (
|
||||
await response.json()
|
||||
if response.content_type == "application/json"
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
raise Exception(r.status_code, res["detail"])
|
||||
raise Exception(response.status, res["detail"])
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
r = requests.post(
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
payload = data
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
try:
|
||||
res = r.json()
|
||||
res = (
|
||||
await response.json()
|
||||
if "application/json" in response.content_type
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
return Exception(r.status_code, res)
|
||||
raise Exception(response.status, res)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
|
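Both filter passes now share one aiohttp session instead of issuing a blocking requests.post per filter. A hedged sketch of a call site awaiting the inlet pass (the payload shape is illustrative; request, user, and models come from FastAPI app state):

payload = {
    "model": "my-pipeline-model",  # hypothetical pipeline-backed model id
    "messages": [{"role": "user", "content": "hi"}],
}
# Each matching filter's /filter/inlet endpoint may rewrite the payload in turn.
payload = await process_pipeline_inlet_filter(request, payload, user, models)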
@@ -161,7 +177,7 @@ router = APIRouter()

@router.get("/list")
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
    responses = await get_all_models_responses(request)
    responses = await get_all_models_responses(request, user)
    log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")

    urlIdxs = [

@@ -188,9 +204,11 @@ async def upload_pipeline(
    file: UploadFile = File(...),
    user=Depends(get_admin_user),
):
    print("upload_pipeline", urlIdx, file.filename)
    log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
    filename = os.path.basename(file.filename)

    # Check if the uploaded file is a python file
    if not (file.filename and file.filename.endswith(".py")):
    if not (filename and filename.endswith(".py")):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only Python (.py) files are allowed.",

@@ -198,7 +216,7 @@ async def upload_pipeline(

    upload_folder = f"{CACHE_DIR}/pipelines"
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, file.filename)
    file_path = os.path.join(upload_folder, filename)

    r = None
    try:

@@ -223,7 +241,7 @@ async def upload_pipeline(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        status_code = status.HTTP_404_NOT_FOUND

@@ -274,7 +292,7 @@ async def add_pipeline(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        if r is not None:

@@ -319,7 +337,7 @@ async def delete_pipeline(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        if r is not None:

@@ -353,7 +371,7 @@ async def get_pipelines(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        if r is not None:

@@ -392,7 +410,7 @@ async def get_pipeline_valves(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        if r is not None:

@@ -432,7 +450,7 @@ async def get_pipeline_valves_spec(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None
        if r is not None:

@@ -474,7 +492,7 @@ async def update_pipeline_valves(
        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")
        log.exception(f"Connection error: {e}")

        detail = None

File diff suppressed because it is too large

@@ -20,6 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS

from open_webui.routers.pipelines import process_pipeline_inlet_filter

from open_webui.utils.task import get_task_model_id

from open_webui.config import (

@@ -182,35 +183,28 @@ async def generate_title(
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    messages = form_data["messages"]

    # Remove reasoning details from the messages
    for message in messages:
        message["content"] = re.sub(
            r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
            "",
            message["content"],
            flags=re.S,
        ).strip()

    content = title_generation_template(
        template,
        messages,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    max_tokens = (
        models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": 1000}
            if models[task_model_id]["owned_by"] == "ollama"
            {"max_tokens": max_tokens}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 1000,
                "max_completion_tokens": max_tokens,
            }
        ),
        "metadata": {
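Title generation now reads the task model's configured max_tokens instead of hardcoding 1000, and keys the parameter name off the backend. The lookup in isolation (the model dict below is illustrative):

models = {"llama3": {"owned_by": "ollama", "info": {"params": {"max_tokens": 256}}}}
task_model_id = "llama3"

max_tokens = (
    models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
)
limit = (
    {"max_tokens": max_tokens}
    if models[task_model_id].get("owned_by") == "ollama"
    else {"max_completion_tokens": max_tokens}
)
# limit -> {"max_tokens": 256} for the Ollama-backed model above;
# an OpenAI-style backend would get {"max_completion_tokens": 256} instead.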
@@ -221,6 +215,12 @@ async def generate_title(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -290,6 +290,12 @@ async def generate_chat_tags(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -356,6 +362,12 @@ async def generate_image_prompt(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -433,6 +445,12 @@ async def generate_queries(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -514,6 +532,12 @@ async def generate_autocompletion(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -571,7 +595,7 @@ async def generate_emoji(
        "stream": False,
        **(
            {"max_tokens": 4}
            if models[task_model_id]["owned_by"] == "ollama"
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }

@@ -584,6 +608,12 @@ async def generate_emoji(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -613,17 +643,6 @@ async def generate_moa_response(
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(

@@ -633,7 +652,7 @@ async def generate_moa_response(
    )

    payload = {
        "model": task_model_id,
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {

@@ -644,6 +663,12 @@ async def generate_moa_response(
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:

@@ -1,5 +1,10 @@
import logging
from pathlib import Path
from typing import Optional
import time
import re
import aiohttp
from pydantic import BaseModel, HttpUrl

from open_webui.models.tools import (
    ToolForm,

@@ -8,13 +13,20 @@ from open_webui.models.tools import (
    ToolUserResponse,
    Tools,
)
from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS

from open_webui.utils.tools import get_tool_servers_data


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


router = APIRouter()

@@ -25,11 +37,51 @@ router = APIRouter()


@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(user=Depends(get_verified_user)):
    if user.role == "admin":
        tools = Tools.get_tools()
    else:
        tools = Tools.get_tools_by_user_id(user.id, "read")
async def get_tools(request: Request, user=Depends(get_verified_user)):

    if not request.app.state.TOOL_SERVERS:
        # If the tool servers are not set, we need to set them
        # This is done only once when the server starts
        # This is done to avoid loading the tool servers every time

        request.app.state.TOOL_SERVERS = await get_tool_servers_data(
            request.app.state.config.TOOL_SERVER_CONNECTIONS
        )

    tools = Tools.get_tools()
    for server in request.app.state.TOOL_SERVERS:
        tools.append(
            ToolUserResponse(
                **{
                    "id": f"server:{server['idx']}",
                    "user_id": f"server:{server['idx']}",
                    "name": server.get("openapi", {})
                    .get("info", {})
                    .get("title", "Tool Server"),
                    "meta": {
                        "description": server.get("openapi", {})
                        .get("info", {})
                        .get("description", ""),
                    },
                    "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
                        server["idx"]
                    ]
                    .get("config", {})
                    .get("access_control", None),
                    "updated_at": int(time.time()),
                    "created_at": int(time.time()),
                }
            )
        )

    if user.role != "admin":
        tools = [
            tool
            for tool in tools
            if tool.user_id == user.id
            or has_access(user.id, "read", tool.access_control)
        ]

    return tools


@@ -47,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
    return tools


############################
# LoadFunctionFromLink
############################


class LoadUrlForm(BaseModel):
    url: HttpUrl


def github_url_to_raw_url(url: str) -> str:
    # Handle 'tree' (folder) URLs (add main.py at the end)
    m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
    if m1:
        org, repo, branch, path = m1.groups()
        return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"

    # Handle 'blob' (file) URLs
    m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
    if m2:
        org, repo, branch, path = m2.groups()
        return (
            f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
        )

    # No match; return as-is
    return url
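github_url_to_raw_url maps both folder ('tree') and file ('blob') GitHub links onto raw.githubusercontent.com, defaulting folder imports to a main.py. Two illustrative conversions (the repository names are made up):

github_url_to_raw_url("https://github.com/acme/tools/blob/main/weather.py")
# -> "https://raw.githubusercontent.com/acme/tools/refs/heads/main/weather.py"

github_url_to_raw_url("https://github.com/acme/tools/tree/main/weather")
# -> "https://raw.githubusercontent.com/acme/tools/refs/heads/main/weather/main.py"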
@router.post("/load/url", response_model=Optional[dict])
|
||||
async def load_tool_from_url(
|
||||
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# NOTE: This is NOT a SSRF vulnerability:
|
||||
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
||||
# and does NOT accept untrusted user input. Access is enforced by authentication.
|
||||
|
||||
url = str(form_data.url)
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
||||
|
||||
url = github_url_to_raw_url(url)
|
||||
url_parts = url.rstrip("/").split("/")
|
||||
|
||||
file_name = url_parts[-1]
|
||||
tool_name = (
|
||||
file_name[:-3]
|
||||
if (
|
||||
file_name.endswith(".py")
|
||||
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
||||
)
|
||||
else url_parts[-2] if len(url_parts) > 1 else "function"
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=resp.status, detail="Failed to fetch the tool"
|
||||
)
|
||||
data = await resp.text()
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No data received from the URL"
|
||||
)
|
||||
return {
|
||||
"name": tool_name,
|
||||
"content": data,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
|
||||
|
||||
|
||||
############################
|
||||
# ExportTools
|
||||
############################
|
||||
|
|
@@ -89,18 +216,18 @@ async def create_new_tools(
    if tools is None:
        try:
            form_data.content = replace_imports(form_data.content)
            tools_module, frontmatter = load_tools_module_by_id(
            tool_module, frontmatter = load_tool_module_by_id(
                form_data.id, content=form_data.content
            )
            form_data.meta.manifest = frontmatter

            TOOLS = request.app.state.TOOLS
            TOOLS[form_data.id] = tools_module
            TOOLS[form_data.id] = tool_module

            specs = get_tools_specs(TOOLS[form_data.id])
            specs = get_tool_specs(TOOLS[form_data.id])
            tools = Tools.insert_new_tool(user.id, form_data, specs)

            tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
            tool_cache_dir = CACHE_DIR / "tools" / form_data.id
            tool_cache_dir.mkdir(parents=True, exist_ok=True)

            if tools:

@@ -111,7 +238,7 @@ async def create_new_tools(
                detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
            )
        except Exception as e:
            print(e)
            log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(str(e)),

@@ -178,22 +305,20 @@ async def update_tools_by_id(

    try:
        form_data.content = replace_imports(form_data.content)
        tools_module, frontmatter = load_tools_module_by_id(
            id, content=form_data.content
        )
        tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
        form_data.meta.manifest = frontmatter

        TOOLS = request.app.state.TOOLS
        TOOLS[id] = tools_module
        TOOLS[id] = tool_module

        specs = get_tools_specs(TOOLS[id])
        specs = get_tool_specs(TOOLS[id])

        updated = {
            **form_data.model_dump(exclude={"id"}),
            "specs": specs,
        }

        print(updated)
        log.debug(updated)
        tools = Tools.update_tool_by_id(id, updated)

        if tools:

@@ -284,7 +409,7 @@ async def get_tools_valves_spec_by_id(
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if hasattr(tools_module, "Valves"):

@@ -327,7 +452,7 @@ async def update_tools_valves_by_id(
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if not hasattr(tools_module, "Valves"):

@@ -343,7 +468,7 @@ async def update_tools_valves_by_id(
        Tools.update_tool_valves_by_id(id, valves.model_dump())
        return valves.model_dump()
    except Exception as e:
        print(e)
        log.exception(f"Failed to update tool valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),

@@ -383,7 +508,7 @@ async def get_tools_user_valves_spec_by_id(
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if hasattr(tools_module, "UserValves"):

@@ -407,7 +532,7 @@ async def update_tools_user_valves_by_id(
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if hasattr(tools_module, "UserValves"):

@@ -421,7 +546,7 @@ async def update_tools_user_valves_by_id(
        )
        return user_valves.model_dump()
    except Exception as e:
        print(e)
        log.exception(f"Failed to update user valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),

@@ -2,9 +2,11 @@ import logging
from typing import Optional

from open_webui.models.auths import Auths
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
    UserModel,
    UserListResponse,
    UserRoleUpdateForm,
    Users,
    UserSettings,

@@ -17,7 +19,10 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
from open_webui.utils.access_control import get_permissions, has_permission


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

@@ -29,13 +34,38 @@ router = APIRouter()
############################


@router.get("/", response_model=list[UserModel])
PAGE_ITEM_COUNT = 30


@router.get("/", response_model=UserListResponse)
async def get_users(
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    query: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_admin_user),
):
    return Users.get_users(skip, limit)
    limit = PAGE_ITEM_COUNT

    page = max(1, page)
    skip = (page - 1) * limit

    filter = {}
    if query:
        filter["query"] = query
    if order_by:
        filter["order_by"] = order_by
    if direction:
        filter["direction"] = direction

    return Users.get_users(filter=filter, skip=skip, limit=limit)


@router.get("/all", response_model=UserListResponse)
async def get_all_users(
    user=Depends(get_admin_user),
):
    return Users.get_users()


############################
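The user list above is now server-paginated at a fixed 30 items per page, with the page number converted into a skip offset. The arithmetic in isolation (the helper name and values are illustrative):

PAGE_ITEM_COUNT = 30


def page_to_window(page: int) -> tuple[int, int]:
    # Clamp to page 1 and convert to (skip, limit) for the query layer.
    page = max(1, page)
    return (page - 1) * PAGE_ITEM_COUNT, PAGE_ITEM_COUNT


print(page_to_window(3))   # (60, 30) -> rows 61..90
print(page_to_window(-5))  # (0, 30)  -> clamped back to the first page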
@ -45,7 +75,7 @@ async def get_users(
|
|||
|
||||
@router.get("/groups")
|
||||
async def get_user_groups(user=Depends(get_verified_user)):
|
||||
return Users.get_user_groups(user.id)
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -54,8 +84,12 @@ async def get_user_groups(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/permissions")
|
||||
async def get_user_permissisions(user=Depends(get_verified_user)):
|
||||
return Users.get_user_groups(user.id)
|
||||
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
return user_permissions
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -68,32 +102,52 @@ class WorkspacePermissions(BaseModel):
|
|||
tools: bool = False
|
||||
|
||||
|
||||
class SharingPermissions(BaseModel):
|
||||
public_models: bool = True
|
||||
public_knowledge: bool = True
|
||||
public_prompts: bool = True
|
||||
public_tools: bool = True
|
||||
|
||||
|
||||
class ChatPermissions(BaseModel):
|
||||
controls: bool = True
|
||||
file_upload: bool = True
|
||||
delete: bool = True
|
||||
edit: bool = True
|
||||
share: bool = True
|
||||
export: bool = True
|
||||
stt: bool = True
|
||||
tts: bool = True
|
||||
call: bool = True
|
||||
multiple_models: bool = True
|
||||
temporary: bool = True
|
||||
temporary_enforced: bool = False
|
||||
|
||||
|
||||
class FeaturesPermissions(BaseModel):
|
||||
direct_tool_servers: bool = False
|
||||
web_search: bool = True
|
||||
image_generation: bool = True
|
||||
code_interpreter: bool = True
|
||||
notes: bool = True
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
workspace: WorkspacePermissions
|
||||
sharing: SharingPermissions
|
||||
chat: ChatPermissions
|
||||
features: FeaturesPermissions
|
||||
|
||||
|
||||
@router.get("/default/permissions", response_model=UserPermissions)
|
||||
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"workspace": WorkspacePermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("workspace", {})
|
||||
),
|
||||
"sharing": SharingPermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("sharing", {})
|
||||
),
|
||||
"chat": ChatPermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("chat", {})
|
||||
),
|
||||
|
|
@ -104,7 +158,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.post("/default/permissions")
|
||||
async def update_user_permissions(
|
||||
async def update_default_user_permissions(
|
||||
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
|
||||
|
|
@ -151,9 +205,22 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/user/settings/update", response_model=UserSettings)
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: UserSettings, user=Depends(get_verified_user)
|
||||
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
|
||||
updated_user_settings = form_data.model_dump()
|
||||
if (
|
||||
user.role != "admin"
|
||||
and "toolServers" in updated_user_settings.get("ui").keys()
|
||||
and not has_permission(
|
||||
user.id,
|
||||
"features.direct_tool_servers",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
)
|
||||
):
|
||||
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
||||
updated_user_settings["ui"].pop("toolServers", None)
|
||||
|
||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
@ -263,6 +330,21 @@ async def update_user_by_id(
|
|||
    form_data: UserUpdateForm,
    session_user=Depends(get_admin_user),
):
+    # Prevent modification of the primary admin user by other admins
+    try:
+        first_user = Users.get_first_user()
+        if first_user and user_id == first_user.id and session_user.id != user_id:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+            )
+    except Exception as e:
+        log.error(f"Error checking primary admin status: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Could not verify primary admin status.",
+        )
+
    user = Users.get_user_by_id(user_id)

    if user:
@@ -310,6 +392,21 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+    # Prevent deletion of the primary admin user
+    try:
+        first_user = Users.get_first_user()
+        if first_user and user_id == first_user.id:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+            )
+    except Exception as e:
+        log.error(f"Error checking primary admin status: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Could not verify primary admin status.",
+        )
+
    if user.id != user_id:
        result = Auths.delete_auth_by_id(user_id)
@@ -321,6 +418,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
            detail=ERROR_MESSAGES.DELETE_USER_ERROR,
        )

+    # Prevent self-deletion
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
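One observation on the new guard: it is duplicated between the update and delete routes, and because the 403 HTTPException is raised inside the try block, the generic except Exception re-raises it as a 500. A possible refactor (a sketch, not part of this commit) keeps only the database lookup inside the try and shares the check between both routes:

def ensure_not_primary_admin(user_id: str, session_user_id=None):
    # Look up the first (primary admin) user; only this part can legitimately fail.
    try:
        first_user = Users.get_first_user()
    except Exception as e:
        log.error(f"Error checking primary admin status: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not verify primary admin status.",
        )
    # Raised outside the try, so the 403 is never masked as a 500.
    if first_user and user_id == first_user.id and session_user_id != user_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
        )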
@@ -1,48 +1,84 @@
import black
+import logging
import markdown

from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Response, status
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from starlette.responses import FileResponse


from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
-from open_webui.utils.auth import get_admin_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.code_interpreter import execute_code_jupyter
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

router = APIRouter()


@router.get("/gravatar")
-async def get_gravatar(
-    email: str,
-):
+async def get_gravatar(email: str, user=Depends(get_verified_user)):
    return get_gravatar_url(email)


-class CodeFormatRequest(BaseModel):
+class CodeForm(BaseModel):
    code: str


@router.post("/code/format")
-async def format_code(request: CodeFormatRequest):
+async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
    try:
-        formatted_code = black.format_str(request.code, mode=black.Mode())
+        formatted_code = black.format_str(form_data.code, mode=black.Mode())
        return {"code": formatted_code}
    except black.NothingChanged:
-        return {"code": request.code}
+        return {"code": form_data.code}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


+@router.post("/code/execute")
+async def execute_code(
+    request: Request, form_data: CodeForm, user=Depends(get_verified_user)
+):
+    if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
+        output = await execute_code_jupyter(
+            request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+            form_data.code,
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
+                else None
+            ),
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
+                else None
+            ),
+            request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        )
+
+        return output
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Code execution engine not supported",
+        )


class MarkdownForm(BaseModel):
    md: str


@router.post("/markdown")
async def get_html_from_markdown(
-    form_data: MarkdownForm,
+    form_data: MarkdownForm, user=Depends(get_verified_user)
):
    return {"html": markdown.markdown(form_data.md)}
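A short sketch of exercising the two code endpoints from a client. The /api/v1/utils mount point and the token are assumptions; /code/execute additionally requires CODE_EXECUTION_ENGINE to be "jupyter" with a reachable Jupyter server, otherwise it returns the 400 above.

import requests

BASE_URL = "http://localhost:8080/api/v1/utils"  # assumed mount point
HEADERS = {"Authorization": "Bearer <user-token>"}  # placeholder verified-user JWT

# /code/format runs the snippet through black and returns the result.
resp = requests.post(
    f"{BASE_URL}/code/format", json={"code": "x=1;y=2;print(x+y)"}, headers=HEADERS
)
print(resp.json()["code"])

# /code/execute forwards the snippet to the configured Jupyter kernel.
resp = requests.post(
    f"{BASE_URL}/code/execute", json={"code": "print(1 + 2)"}, headers=HEADERS
)
print(resp.json())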
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
|
|
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating PDF: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
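For completeness, a hedged sketch of the /pdf route from the client side: it accepts a ChatTitleMessagesForm (a title plus a list of messages) and streams back the bytes produced by PDFGenerator with an attachment disposition. The field names and mount point below are assumptions:

import requests

payload = {
    "title": "Demo chat",
    "messages": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ],
}
resp = requests.post(
    "http://localhost:8080/api/v1/utils/pdf",  # assumed mount point
    json=payload,
    headers={"Authorization": "Bearer <user-token>"},  # placeholder
)
resp.raise_for_status()
with open("chat.pdf", "wb") as f:
    f.write(resp.content)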
@@ -3,15 +3,23 @@ import socketio
import logging
import sys
import time
from redis import asyncio as aioredis

from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
+from open_webui.utils.redis import (
+    get_sentinels_from_env,
+    get_sentinel_url_from_env,
+)

from open_webui.env import (
    ENABLE_WEBSOCKET_SUPPORT,
    WEBSOCKET_MANAGER,
    WEBSOCKET_REDIS_URL,
+    WEBSOCKET_REDIS_LOCK_TIMEOUT,
+    WEBSOCKET_SENTINEL_PORT,
+    WEBSOCKET_SENTINEL_HOSTS,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
@@ -28,7 +36,14 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
|
||||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
|
||||
if WEBSOCKET_SENTINEL_HOSTS:
|
||||
mgr = socketio.AsyncRedisManager(
|
||||
get_sentinel_url_from_env(
|
||||
WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
|
||||
)
|
||||
)
|
||||
else:
|
||||
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
|
|
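get_sentinel_url_from_env presumably rewrites the plain Redis URL into a sentinel-aware URL that AsyncRedisManager can consume. The real implementation lives in open_webui.utils.redis and may differ; a rough sketch of the idea, with the "mymaster" service name as an assumption:

from urllib.parse import urlparse


def sentinel_url_sketch(redis_url: str, sentinel_hosts: str, sentinel_port: str) -> str:
    # Keep the credentials and db from the original URL, but point at the sentinels.
    parsed = urlparse(redis_url)
    auth = f"{parsed.username or ''}:{parsed.password}@" if parsed.password else ""
    hosts = ",".join(f"{host}:{sentinel_port}" for host in sentinel_hosts.split(","))
    db = (parsed.path or "/0").lstrip("/") or "0"
    return f"redis+sentinel://{auth}{hosts}/{db}/mymaster"  # service name assumed


print(sentinel_url_sketch("redis://:secret@redis:6379/0", "s1,s2", "26379"))
# -> redis+sentinel://:secret@s1:26379,s2:26379/0/mymaster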
@@ -54,14 +69,30 @@ TIMEOUT_DURATION = 3
if WEBSOCKET_MANAGER == "redis":
    log.debug("Using Redis to manage websockets.")
-    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+    redis_sentinels = get_sentinels_from_env(
+        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+    )
+    SESSION_POOL = RedisDict(
+        "open-webui:session_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )
+    USER_POOL = RedisDict(
+        "open-webui:user_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )
+    USAGE_POOL = RedisDict(
+        "open-webui:usage_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )

    clean_up_lock = RedisLock(
        redis_url=WEBSOCKET_REDIS_URL,
        lock_name="usage_cleanup_lock",
-        timeout_secs=TIMEOUT_DURATION * 2,
+        timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
+        redis_sentinels=redis_sentinels,
    )
    aquire_func = clean_up_lock.aquire_lock
    renew_func = clean_up_lock.renew_lock
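The acquire/renew pair (aquire_lock is the method's actual spelling in this codebase) elects a single cleanup worker across replicas: whichever instance sets the Redis key wins, and it must renew the key before WEBSOCKET_REDIS_LOCK_TIMEOUT elapses so a slow pass does not lose the lock mid-work. A hedged sketch of that pattern; release_lock is an assumed method name:

def run_cleanup_once(acquire, renew, do_cleanup, release=None):
    # Only the replica holding the lock performs this cycle; the rest skip it.
    if not acquire():
        return False
    try:
        do_cleanup()
        renew()  # extend the TTL so the lock survives a slow pass
        return True
    finally:
        if release:
            release()


# e.g. run_cleanup_once(aquire_func, renew_func, lambda: None,
#                       release=getattr(clean_up_lock, "release_lock", None))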
@@ -128,18 +159,19 @@ def get_models_in_use():
@sio.on("usage")
|
||||
async def usage(sid, data):
|
||||
model_id = data["model"]
|
||||
# Record the timestamp for the last update
|
||||
current_time = int(time.time())
|
||||
if sid in SESSION_POOL:
|
||||
model_id = data["model"]
|
||||
# Record the timestamp for the last update
|
||||
current_time = int(time.time())
|
||||
|
||||
# Store the new usage data and task
|
||||
USAGE_POOL[model_id] = {
|
||||
**(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
|
||||
sid: {"updated_at": current_time},
|
||||
}
|
||||
# Store the new usage data and task
|
||||
USAGE_POOL[model_id] = {
|
||||
**(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
|
||||
sid: {"updated_at": current_time},
|
||||
}
|
||||
|
||||
# Broadcast the usage data to all clients
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
# Broadcast the usage data to all clients
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
|
||||
|
||||
@sio.event
|
||||
|
|
@@ -247,7 +279,8 @@ async def channel_events(sid, data):
@sio.on("user-list")
|
||||
async def user_list(sid):
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
if sid in SESSION_POOL:
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
|
||||
|
||||
@sio.event
|
||||
|
|
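Both handlers now drop events from sids that are not registered in SESSION_POOL, so a client must authenticate before its usage or user-list emits have any effect. A sketch with the python-socketio client; the user-join event name, its payload shape, and the socket path are assumptions beyond this diff:

import socketio

sio = socketio.Client()


@sio.on("usage")
def on_usage(data):
    print("models in use:", data["models"])


sio.connect("http://localhost:8080", socketio_path="/ws/socket.io")  # path assumed
sio.emit("user-join", {"auth": {"token": "<jwt>"}})  # registers the sid in SESSION_POOL
sio.emit("usage", {"model": "llama3"})  # now passes the `sid in SESSION_POOL` check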
@@ -268,15 +301,23 @@ async def disconnect(sid):
# print(f"Unknown session ID {sid} disconnected")
|
||||
|
||||
|
||||
def get_event_emitter(request_info):
|
||||
def get_event_emitter(request_info, update_db=True):
|
||||
async def __event_emitter__(event_data):
|
||||
user_id = request_info["user_id"]
|
||||
|
||||
session_ids = list(
|
||||
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
|
||||
set(
|
||||
USER_POOL.get(user_id, [])
|
||||
+ (
|
||||
[request_info.get("session_id")]
|
||||
if request_info.get("session_id")
|
||||
else []
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for session_id in session_ids:
|
||||
await sio.emit(
|
||||
emit_tasks = [
|
||||
sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info.get("chat_id", None),
|
||||
|
|
@@ -285,41 +326,47 @@ def get_event_emitter(request_info):
                },
                to=session_id,
            )
+            for session_id in session_ids
+        ]

-        if "type" in event_data and event_data["type"] == "status":
-            Chats.add_message_status_to_chat_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-                event_data.get("data", {}),
-            )
+        await asyncio.gather(*emit_tasks)

-        if "type" in event_data and event_data["type"] == "message":
-            message = Chats.get_message_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-            )
+        if update_db:
+            if "type" in event_data and event_data["type"] == "status":
+                Chats.add_message_status_to_chat_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                    event_data.get("data", {}),
+                )

-            content = message.get("content", "")
-            content += event_data.get("data", {}).get("content", "")
+            if "type" in event_data and event_data["type"] == "message":
+                message = Chats.get_message_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                )

-            Chats.upsert_message_to_chat_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-                {
-                    "content": content,
-                },
-            )
+                if message:
+                    content = message.get("content", "")
+                    content += event_data.get("data", {}).get("content", "")

-        if "type" in event_data and event_data["type"] == "replace":
-            content = event_data.get("data", {}).get("content", "")
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        request_info["chat_id"],
+                        request_info["message_id"],
+                        {
+                            "content": content,
+                        },
+                    )

-            Chats.upsert_message_to_chat_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-                {
-                    "content": content,
-                },
-            )
+            if "type" in event_data and event_data["type"] == "replace":
+                content = event_data.get("data", {}).get("content", "")
+
+                Chats.upsert_message_to_chat_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                    {
+                        "content": content,
+                    },
+                )

    return __event_emitter__
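Two notes on the reworked emitter: asyncio.gather fans the per-session emits out concurrently, which relies on an `import asyncio` at the top of the module (outside this hunk), and update_db=False lets callers broadcast transient events without writing to the chat record. A hedged usage sketch:

async def announce_status(user_id, session_id, chat_id, message_id):
    emitter = get_event_emitter(
        {
            "user_id": user_id,
            "session_id": session_id,  # may be None; the emitter now tolerates that
            "chat_id": chat_id,
            "message_id": message_id,
        },
        update_db=False,  # broadcast only; skip the Chats.* persistence branches
    )
    await emitter({"type": "status", "data": {"description": "Searching the web..."}})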
@@ -1,15 +1,17 @@
import json
import redis
import uuid
+from open_webui.utils.redis import get_redis_connection


class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs):
+    def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
        self.lock_name = lock_name
        self.lock_id = str(uuid.uuid4())
        self.timeout_secs = timeout_secs
        self.lock_obtained = False
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(
+            redis_url, redis_sentinels, decode_responses=True
+        )

    def aquire_lock(self):
        # nx=True will only set this key if it _hasn't_ already been set
@@ -31,9 +33,11 @@ class RedisLock:


class RedisDict:
-    def __init__(self, name, redis_url):
+    def __init__(self, name, redis_url, redis_sentinels=[]):
        self.name = name
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(
+            redis_url, redis_sentinels, decode_responses=True
+        )

    def __setitem__(self, key, value):
        serialized_value = json.dumps(value)
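get_redis_connection presumably hides the sentinel/direct split behind one call: with sentinels it asks a Sentinel pool for the current master, otherwise it falls back to redis.Redis.from_url. A plausible sketch; the real implementation is in open_webui.utils.redis, and the service name and URL handling here are assumptions:

import redis
from redis.sentinel import Sentinel
from urllib.parse import urlparse


def get_redis_connection_sketch(redis_url, redis_sentinels, decode_responses=True):
    if redis_sentinels:
        parsed = urlparse(redis_url)
        sentinel = Sentinel(
            redis_sentinels,  # list of (host, port) tuples
            sentinel_kwargs={"password": parsed.password},
        )
        # Ask the sentinels for the current master of the named service.
        return sentinel.master_for(
            "mymaster",  # assumed service name
            db=int((parsed.path or "/0").lstrip("/") or 0),
            password=parsed.password,
            decode_responses=decode_responses,
        )
    return redis.Redis.from_url(redis_url, decode_responses=decode_responses)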
BIN backend/open_webui/static/apple-touch-icon.png (new file, 7.3 KiB; binary not shown)
@@ -269,11 +269,6 @@ tbody + tbody {
    margin-bottom: 0;
}

-/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
-.markdown-section p + ul {
-    margin-top: 0;
-}
-
/* List item styles */
.markdown-section li {
    padding: 2px;
BIN backend/open_webui/static/favicon-96x96.png (new file, 3.7 KiB; binary not shown)
BIN backend/open_webui/static/favicon-dark.png (new file, 16 KiB; binary not shown)
Some files were not shown because too many files have changed in this diff.