feat: merge with main
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 80 lines removed)

@@ -1,80 +0,0 @@
----
-name: Bug report
-about: Create a report to help us improve
-title: ''
-labels: ''
-assignees: ''
----
-
-# Bug Report
-
-## Important Notes
-
-- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
-
-- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
-
-- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
-
-- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
-
-Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
-
----
-
-## Installation Method
-
-[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
-
-## Environment
-
-- **Open WebUI Version:** [e.g., v0.3.11]
-- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
-
-- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
-- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
-
-**Confirmation:**
-
-- [ ] I have read and followed all the instructions provided in the README.md.
-- [ ] I am on the latest version of both Open WebUI and Ollama.
-- [ ] I have included the browser console logs.
-- [ ] I have included the Docker container logs.
-- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
-
-## Expected Behavior:
-
-[Describe what you expected to happen.]
-
-## Actual Behavior:
-
-[Describe what actually happened.]
-
-## Description
-
-**Bug Summary:**
-[Provide a brief but clear summary of the bug]
-
-## Reproduction Details
-
-**Steps to Reproduce:**
-[Outline the steps to reproduce the bug. Be as detailed as possible.]
-
-## Logs and Screenshots
-
-**Browser Console Logs:**
-[Include relevant browser console logs, if applicable]
-
-**Docker Container Logs:**
-[Include relevant Docker container logs, if applicable]
-
-**Screenshots/Screen Recordings (if applicable):**
-[Attach any relevant screenshots to help illustrate the issue]
-
-## Additional Information
-
-[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
-
-## Note
-
-If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!
.github/ISSUE_TEMPLATE/bug_report.yaml (vendored, new file, 144 lines added)

@@ -0,0 +1,144 @@
+name: Bug Report
+description: Create a detailed bug report to help us improve Open WebUI.
+title: 'issue: '
+labels: ['bug', 'triage']
+assignees: []
+body:
+  - type: markdown
+    attributes:
+      value: |
+        # Bug Report
+
+        ## Important Notes
+
+        - **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
+
+        - **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
+
+        - **Contributing**: If you encounter an issue, consider submitting a pull request or forking the project. We prioritize preventing contributor burnout to maintain Open WebUI's quality.
+
+        - **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
+
+  - type: checkboxes
+    id: issue-check
+    attributes:
+      label: Check Existing Issues
+      description: Confirm that you’ve checked for existing reports before submitting a new one.
+      options:
+        - label: I have searched the existing issues and discussions.
+          required: true
+
+  - type: dropdown
+    id: installation-method
+    attributes:
+      label: Installation Method
+      description: How did you install Open WebUI?
+      options:
+        - Git Clone
+        - Pip Install
+        - Docker
+        - Other
+    validations:
+      required: true
+
+  - type: input
+    id: open-webui-version
+    attributes:
+      label: Open WebUI Version
+      description: Specify the version (e.g., v0.3.11)
+    validations:
+      required: true
+
+  - type: input
+    id: ollama-version
+    attributes:
+      label: Ollama Version (if applicable)
+      description: Specify the version (e.g., v0.2.0, or v0.1.32-rc1)
+    validations:
+      required: false
+
+  - type: input
+    id: operating-system
+    attributes:
+      label: Operating System
+      description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
+    validations:
+      required: true
+
+  - type: input
+    id: browser
+    attributes:
+      label: Browser (if applicable)
+      description: Specify the browser/version (e.g., Chrome 100.0, Firefox 98.0)
+    validations:
+      required: false
+
+  - type: checkboxes
+    id: confirmation
+    attributes:
+      label: Confirmation
+      description: Ensure the following prerequisites have been met.
+      options:
+        - label: I have read and followed all instructions in `README.md`.
+          required: true
+        - label: I am using the latest version of **both** Open WebUI and Ollama.
+          required: true
+        - label: I have checked the browser console logs.
+          required: true
+        - label: I have checked the Docker container logs.
+          required: true
+        - label: I have listed steps to reproduce the bug in detail.
+          required: true
+
+  - type: textarea
+    id: expected-behavior
+    attributes:
+      label: Expected Behavior
+      description: Describe what should have happened.
+    validations:
+      required: true
+
+  - type: textarea
+    id: actual-behavior
+    attributes:
+      label: Actual Behavior
+      description: Describe what actually happened.
+    validations:
+      required: true
+
+  - type: textarea
+    id: reproduction-steps
+    attributes:
+      label: Steps to Reproduce
+      description: Provide step-by-step instructions to reproduce the issue.
+      placeholder: |
+        1. Go to '...'
+        2. Click on '...'
+        3. Scroll down to '...'
+        4. See the error message '...'
+    validations:
+      required: true
+
+  - type: textarea
+    id: logs-screenshots
+    attributes:
+      label: Logs & Screenshots
+      description: Include relevant logs, errors, or screenshots to help diagnose the issue.
+      placeholder: 'Attach logs from the browser console, Docker logs, or error messages.'
+    validations:
+      required: true
+
+  - type: textarea
+    id: additional-info
+    attributes:
+      label: Additional Information
+      description: Provide any extra details that may assist in understanding the issue.
+    validations:
+      required: false
+
+  - type: markdown
+    attributes:
+      value: |
+        ## Note
+        If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
+        Thank you for contributing to Open WebUI!
.github/ISSUE_TEMPLATE/feature_request.md (vendored, 35 lines removed)

@@ -1,35 +0,0 @@
----
-name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: ''
-assignees: ''
----
-
-# Feature Request
-
-## Important Notes
-
-- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
-
-- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
-
-- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
-
-- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
-
-Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
-
----
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
-
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
-
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
-**Additional context**
-Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/feature_request.yaml (vendored, new file, 64 lines added)

@@ -0,0 +1,64 @@
+name: Feature Request
+description: Suggest an idea for this project
+title: 'feat: '
+labels: ['triage']
+body:
+  - type: markdown
+    attributes:
+      value: |
+        ## Important Notes
+        ### Before submitting
+        Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
+        It's likely we're already tracking it! If you’re unsure, start a discussion post first.
+        This will help us efficiently focus on improving the project.
+
+        ### Collaborate respectfully
+        We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
+
+        Remember:
+        - Open WebUI is a **volunteer-driven project**
+        - It's managed by a **single maintainer**
+        - It's supported by contributors who also have **full-time jobs**
+
+        We appreciate your time and ask that you **respect ours**.
+
+        ### Contributing
+        If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
+
+        ### Bug reproducibility
+        If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a `pip install` with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "[issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
+
+  - type: checkboxes
+    id: existing-issue
+    attributes:
+      label: Check Existing Issues
+      description: Please confirm that you've checked for existing similar requests
+      options:
+        - label: I have searched the existing issues and discussions.
+          required: true
+
+  - type: textarea
+    id: problem-description
+    attributes:
+      label: Problem Description
+      description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
+      placeholder: "Ex. I'm always frustrated when..."
+    validations:
+      required: true
+
+  - type: textarea
+    id: solution-description
+    attributes:
+      label: Desired Solution
+      description: Clearly describe what you want to happen.
+    validations:
+      required: true
+
+  - type: textarea
+    id: alternatives-considered
+    attributes:
+      label: Alternatives Considered
+      description: A clear and concise description of any alternative solutions or features you've considered.
+
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional Context
+      description: Add any other context or screenshots about the feature request here.
.github/workflows/docker-build.yaml (vendored, 6 changed lines)

@@ -14,7 +14,7 @@ env:

 jobs:
   build-main-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
     permissions:
       contents: read
       packages: write

@@ -111,7 +111,7 @@ jobs:
           retention-days: 1

   build-cuda-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
     permissions:
       contents: read
       packages: write

@@ -211,7 +211,7 @@ jobs:
           retention-days: 1

   build-ollama-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    permissions:
       contents: read
       packages: write
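GitHub Actions expressions have no dedicated ternary operator, so the new `runs-on` value relies on the short-circuit `cond && a || b` idiom to pick a native ARM runner for `linux/arm64` matrix entries. The same logic written with Python's classic `and`/`or` idiom, as a minimal illustrative sketch (not part of the workflow itself):

```python
def pick_runner(platform: str) -> str:
    # Evaluates to 'ubuntu-24.04-arm' only when the platform matches;
    # otherwise the `or` branch supplies the fallback runner.
    return platform == "linux/arm64" and "ubuntu-24.04-arm" or "ubuntu-latest"

assert pick_runner("linux/arm64") == "ubuntu-24.04-arm"
assert pick_runner("linux/amd64") == "ubuntu-latest"
# Caveat (applies to the Actions expression too): if the "true" branch were
# a falsy value such as '', the idiom would incorrectly fall through to the
# fallback; it is safe here because runner labels are non-empty strings.
```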
.github/workflows/release-pypi.yml (vendored, 2 changed lines)

@@ -19,7 +19,7 @@ jobs:
        uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
-          node-version: 18
+          node-version: 22
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
CHANGELOG.md (179 changed lines)

@@ -5,6 +5,185 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.5.19] - 2025-03-04
+
+### Added
+
+- **📊 Logit Bias Parameter Support**: Fine-tune conversation dynamics by adjusting the Logit Bias parameter directly in chat settings, giving you more control over model responses.
+- **⌨️ Customizable Enter Behavior**: You can now configure Enter to send messages only when combined with Ctrl (Ctrl+Enter) via Settings > Interface, preventing accidental message sends.
+- **📝 Collapsible Code Blocks**: Easily collapse long code blocks to declutter your chat, making it easier to focus on important details.
+- **🏷️ Tag Selector in Model Selector**: Quickly find and categorize models with the new tag filtering system in the Model Selector, streamlining model discovery.
+- **📈 Experimental Elasticsearch Vector DB Support**: Now supports Elasticsearch as a vector database, offering more flexibility for data retrieval in Retrieval-Augmented Generation (RAG) workflows.
+- **⚙️ General Reliability Enhancements**: Various stability improvements across the WebUI, ensuring a smoother, more consistent experience.
+- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
+
+### Fixed
+
+- **🔄 "Stream" Hook Activation**: Fixed an issue where the "Stream" hook only worked when globally enabled, ensuring reliable real-time filtering.
+- **📧 LDAP Email Case Sensitivity**: Resolved an issue where LDAP login failed due to email case sensitivity mismatches, improving authentication reliability.
+- **💬 WebSocket Chat Event Registration**: Fixed a bug preventing chat event listeners from being registered upon sign-in, ensuring real-time updates work properly.
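For context on the Logit Bias entry above: OpenAI-compatible chat APIs express this as a `logit_bias` map from token IDs to a bias value in [-100, 100]. How Open WebUI plumbs the chat-settings value into the request is not shown in this diff, so the payload below is only a hedged illustration with an arbitrary example token ID:

```python
import json

# Illustrative OpenAI-style request body carrying logit_bias; the token ID
# "1734" is a made-up example, and real IDs depend on the model's tokenizer.
payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Pick a color."}],
    "logit_bias": {"1734": -100},  # -100 effectively bans the token
}
print(json.dumps(payload, indent=2))
```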
+## [0.5.18] - 2025-02-27
+
+### Fixed
+
+- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity.
+- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility.
+- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing.
+
+## [0.5.17] - 2025-02-27
+
+### Added
+
+- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking.
+- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering. Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter).
+- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options.
+- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient.
+- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience.
+- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
+
+### Fixed
+
+- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors.
+- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data.
+- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant.
+
+### Changed
+
+- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI.
+
+## [0.5.16] - 2025-02-20
+
+### Fixed
+
+- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again.
+
+## [0.5.15] - 2025-02-20
+
+### Added
+
+- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding.
+- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval.
+- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data.
+- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI.
+- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions.
+- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs.
+- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations.
+- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI.
+- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience.
+
+### Fixed
+
+- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue.
+- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow.
+
+## [0.5.14] - 2025-02-17
+
+### Fixed
+
+- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability.
+
+## [0.5.13] - 2025-02-17
+
+### Added
+
+- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results.
+- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval.
+- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts.
+- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat.
+- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history.
+- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data.
+- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience.
+
+### Fixed
+
+- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow.
+
+### Changed
+
+- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; now requires double curly brackets '{{}}' for consistency and clarity (see the substitution sketch below).
+- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages.
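Regarding the 0.5.13 prompt-variable change above, a minimal sketch of how double-curly substitution can be implemented; this is illustrative only and not Open WebUI's actual templating code:

```python
import re

def render_prompt(template: str, variables: dict[str, str]) -> str:
    # Replace each {{NAME}} with its value; leave unknown variables intact
    # so a missing value is visible in the rendered output.
    def substitute(match: re.Match) -> str:
        name = match.group(1)
        return variables.get(name, match.group(0))
    return re.sub(r"\{\{(\w+)\}\}", substitute, template)

print(render_prompt("Summarize {{TOPIC}} in {{LANG}}.", {"TOPIC": "RAG", "LANG": "English"}))
# -> Summarize RAG in English.
```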
+## [0.5.12] - 2025-02-13
+
+### Added
+
+- **🛠️ Multiple Tool Calls Support for Native Function Mode**: Functions now can call multiple tools within a single response, unlocking better automation and workflow flexibility when using native function calling.
+
+### Fixed
+
+- **📝 Playground Text Completion Restored**: Addressed an issue where text completion in the Playground was not functioning.
+- **🔗 Direct Connections Now Work for Regular Users**: Fixed a bug where users with the 'user' role couldn't establish direct API connections, enabling seamless model usage for all user tiers.
+- **⚡ Landing Page Input No Longer Lags with Long Text**: Improved input responsiveness on the landing page, ensuring fast and smooth typing experiences even when entering long messages.
+- **🔧 Parameter in Functions Fixed**: Fixed an issue where reserved parameters weren’t recognized within functions, restoring full functionality for advanced task-based automation.
+
+## [0.5.11] - 2025-02-13
+
+### Added
+
+- **🎤 Kokoro-JS TTS Support**: A new on-device, high-quality text-to-speech engine has been integrated, vastly improving voice generation quality—everything runs directly in your browser.
+- **🐍 Jupyter Notebook Support in Code Interpreter**: Now, you can configure Code Interpreter to run Python code not only via Pyodide but also through Jupyter, offering a more robust coding environment for AI-driven computations and analysis.
+- **🔗 Direct API Connections for Private & Local Inference**: You can now connect Open WebUI to your private or localhost API inference endpoints. CORS must be enabled, but this unlocks direct, on-device AI infrastructure support.
+- **🔍 Advanced Domain Filtering for Web Search**: You can now specify which domains should be included or excluded from web searches, refining results for more relevant information retrieval.
+- **🚀 Improved Image Generation Metadata Handling**: Generated images now retain metadata for better organization and future retrieval.
+- **📂 S3 Key Prefix Support**: Fine-grained control over S3 storage file structuring with configurable key prefixes.
+- **📸 Support for Image-Only Messages**: Send messages containing only images, facilitating more visual-centric interactions.
+- **🌍 Updated Translations**: German, Spanish, Traditional Chinese, and Catalan translations updated for better multilingual support.
+
+### Fixed
+
+- **🔧 OAuth Debug Logs & Username Claim Fixes**: Debug logs have been added for OAuth role and group management, with fixes ensuring proper OAuth username retrieval and claim handling.
+- **📌 Citations Formatting & Toggle Fixes**: Inline citation toggles now function correctly, and citations with more than three sources are now fully visible when expanded.
+- **📸 ComfyUI Maximum Seed Value Constraint Fixed**: The maximum allowed seed value for ComfyUI has been corrected, preventing unintended behavior.
+- **🔑 Connection Settings Stability**: Addressed connection settings issues that were causing instability when saving configurations.
+- **📂 GGUF Model Upload Stability**: Fixed upload inconsistencies for GGUF models, ensuring reliable local model handling.
+- **🔧 Web Search Configuration Bug**: Fixed issues where web search filters and settings weren't correctly applied.
+- **💾 User Settings Persistence Fix**: Ensured user-specific settings are correctly saved and applied across sessions.
+- **🔄 OpenID Username Retrieval Enhancement**: Usernames are now correctly picked up and assigned for OpenID Connect (OIDC) logins.
+
+## [0.5.10] - 2025-02-05
+
+### Fixed
+
+- **⚙️ System Prompts Now Properly Templated via API**: Resolved an issue where system prompts were not being correctly processed when used through the API, ensuring template variables now function as expected.
+- **📝 '<thinking>' Tag Display Issue Fixed**: Fixed a bug where the 'thinking' tag was disrupting content rendering, ensuring clean and accurate text display.
+- **💻 Code Interpreter Stability with Custom Functions**: Addressed failures when using the Code Interpreter with certain custom functions like Anthropic, ensuring smoother execution and better compatibility.
+
+## [0.5.9] - 2025-02-05
+
+### Fixed
+
+- **💡 "Think" Tag Display Issue**: Resolved a bug where the "Think" tag was not functioning correctly, ensuring proper visualization of the model's reasoning process before delivering responses.
+
+## [0.5.8] - 2025-02-05
+
+### Added
+
+- **🖥️ Code Interpreter**: Models can now execute code in real time to refine their answers dynamically, running securely within a sandboxed browser environment using Pyodide. Perfect for calculations, data analysis, and AI-assisted coding tasks!
+- **💬 Redesigned Chat Input UI**: Enjoy a sleeker and more intuitive message input with improved feature selection, making it easier than ever to toggle tools, enable search, and interact with AI seamlessly.
+- **🛠️ Native Tool Calling Support (Experimental)**: Supported models can now call tools natively, reducing query latency and improving contextual responses. More enhancements coming soon!
+- **🔗 Exa Search Engine Integration**: A new search provider has been added, allowing users to retrieve up-to-date and relevant information without leaving the chat interface.
+- **🌍 Localized Dates & Times**: Date and time formats now match your system locale, ensuring a more natural, region-specific experience.
+- **📎 User Headers for External Embedding APIs**: API calls to external embedding services now include user-related headers.
+- **🌍 "Always On" Web Search Toggle**: A new option under Settings > Interface allows users to enable Web Search by default—transform Open WebUI into your go-to search engine, ensuring AI-powered results with every query.
+- **🚀 General Performance & Stability**: Significant improvements across the platform for a faster, more reliable experience.
+- **🖼️ UI/UX Enhancements**: Numerous design refinements improving readability, responsiveness, and accessibility.
+- **🌍 Improved Translations**: Chinese, Korean, French, Ukrainian and Serbian translations have been updated with refined terminologies for better clarity.
+
+### Fixed
+
+- **🔄 OAuth Name Field Fallback**: Resolves OAuth login failures by using the email field as a fallback when a name is missing.
+- **🔑 Google Drive Credentials Restriction**: Ensures only authenticated users can access Google Drive credentials for enhanced security.
+- **🌐 DuckDuckGo Search Rate Limit Handling**: Fixes issues where users would encounter 202 errors due to rate limits when using DuckDuckGo for web search.
+- **📁 File Upload Permission Indicator**: Users are now notified when they lack permission to upload files, improving clarity on system restrictions.
+- **🔧 Max Tokens Issue**: Fixes cases where 'max_tokens' were not applied correctly, ensuring proper model behavior.
+- **🔍 Validation for RAG Web Search URLs**: Filters out invalid or unsupported URLs when using web-based retrieval augmentation.
+- **🖋️ Title Generation Bug**: Fixes inconsistencies in title generation, ensuring proper chat organization.
+
+### Removed
+
+- **⚡ Deprecated Non-Web Worker Pyodide Execution**: Moves entirely to browser sandboxing for better performance and security.
+
 ## [0.5.7] - 2025-01-23

 ### Added
README.md (11 changed lines)

@@ -27,10 +27,15 @@ git push origin main

 **Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.

-For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
-
 

+> [!TIP]
+> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
+>
+> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
+
+For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
+
 ## Key Features of Open WebUI ⭐

 - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.

@@ -188,7 +193,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower

 In the last part of the command, replace `open-webui` with your container name if it is different.

-Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
+Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).

 ### Using the Dev Branch 🌙
backend/open_webui/config.py

@@ -2,12 +2,13 @@ import json
 import logging
 import os
 import shutil
+import base64

 from datetime import datetime
 from pathlib import Path
 from typing import Generic, Optional, TypeVar
 from urllib.parse import urlparse

 import chromadb
 import requests
 from pydantic import BaseModel
 from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -42,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())

 # Function to run the alembic migrations
 def run_migrations():
-    print("Running migrations")
+    log.info("Running migrations")
     try:
         from alembic import command
         from alembic.config import Config
@@ -55,7 +56,7 @@ def run_migrations():

         command.upgrade(alembic_cfg, "head")
     except Exception as e:
-        print(f"Error: {e}")
+        log.exception(f"Error running migrations: {e}")


 run_migrations()
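A side note on the `print` to `log.exception` change above: `Logger.exception` logs at ERROR level and appends the active traceback, which a bare `print` discards. A minimal standalone sketch using stdlib `logging` (Open WebUI routes this through its own `log` object, configured elsewhere):

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

try:
    1 / 0  # stand-in for a failing migration step
except Exception as e:
    # Unlike print(f"Error: {e}"), this records the full traceback,
    # which is what makes migration failures diagnosable from logs.
    log.exception(f"Error running migrations: {e}")
```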
@@ -586,6 +587,14 @@ load_oauth_providers()

 STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()

+for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"):
+    if file_path.is_file():
+        target_path = STATIC_DIR / file_path.relative_to(
+            (FRONTEND_BUILD_DIR / "static")
+        )
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copyfile(file_path, target_path)
+
 frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"

 if frontend_favicon.exists():
@@ -593,8 +602,6 @@ if frontend_favicon.exists():
         shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
     except Exception as e:
         logging.error(f"An error occurred: {e}")
-else:
-    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

 frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
@@ -603,12 +610,18 @@ if frontend_splash.exists():
         shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
     except Exception as e:
         logging.error(f"An error occurred: {e}")
-else:
-    logging.warning(f"Frontend splash not found at {frontend_splash}")
+
+frontend_loader = FRONTEND_BUILD_DIR / "static" / "loader.js"
+
+if frontend_loader.exists():
+    try:
+        shutil.copyfile(frontend_loader, STATIC_DIR / "loader.js")
+    except Exception as e:
+        logging.error(f"An error occurred: {e}")


 ####################################
-# CUSTOM_NAME
+# CUSTOM_NAME (Legacy)
 ####################################

 CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
@@ -650,6 +663,12 @@ if CUSTOM_NAME:
         pass


+####################################
+# LICENSE_KEY
+####################################
+
+LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
+
 ####################################
 # STORAGE PROVIDER
 ####################################
@@ -660,27 +679,47 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
 S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
 S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
 S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
+S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
 S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
+S3_USE_ACCELERATE_ENDPOINT = (
+    os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true"
+)
+S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None)

 GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
 GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
     "GOOGLE_APPLICATION_CREDENTIALS_JSON", None
 )

+AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None)
+AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None)
+AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None)
+
 ####################################
 # File Upload DIR
 ####################################

-UPLOAD_DIR = f"{DATA_DIR}/uploads"
-Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
+UPLOAD_DIR = DATA_DIR / "uploads"
+UPLOAD_DIR.mkdir(parents=True, exist_ok=True)


 ####################################
 # Cache DIR
 ####################################

-CACHE_DIR = f"{DATA_DIR}/cache"
-Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+CACHE_DIR = DATA_DIR / "cache"
+CACHE_DIR.mkdir(parents=True, exist_ok=True)


+####################################
+# DIRECT CONNECTIONS
+####################################
+
+ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
+    "ENABLE_DIRECT_CONNECTIONS",
+    "direct.enable",
+    os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
+)
+
 ####################################
 # OLLAMA_BASE_URL
 ####################################
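`PersistentConfig(env_name, config_path, default)` is the recurring pattern in this file, as in the `ENABLE_DIRECT_CONNECTIONS` block above. Its implementation is not part of this diff; the sketch below is only a hedged approximation of the wrapper's shape, plus a hypothetical helper for the repeated `.lower() == "true"` boolean-env idiom:

```python
import os
from typing import Generic, TypeVar

T = TypeVar("T")

class PersistentConfig(Generic[T]):
    """Hedged approximation: an env-seeded value addressable by a dotted
    config path (e.g. "direct.enable"); the real class in open_webui.config
    also persists runtime overrides to the database."""

    def __init__(self, env_name: str, config_path: str, value: T):
        self.env_name = env_name
        self.config_path = config_path
        self.value = value

def env_bool(name: str, default: str = "False") -> bool:
    # Hypothetical helper mirroring the idiom above: any casing of "true"
    # enables the flag; everything else disables it.
    return os.environ.get(name, default).lower() == "true"

ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
    "ENABLE_DIRECT_CONNECTIONS",
    "direct.enable",
    env_bool("ENABLE_DIRECT_CONNECTIONS", "True"),
)
print(ENABLE_DIRECT_CONNECTIONS.value)  # True unless the env var disables it
```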
@@ -755,6 +794,9 @@ ENABLE_OPENAI_API = PersistentConfig(
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

+GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
+GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
+
 if OPENAI_API_BASE_URL == "":
     OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@@ -927,6 +969,12 @@ USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = (
     == "true"
 )

+USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
+    os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower()
+    == "true"
+)
+
 DEFAULT_USER_PERMISSIONS = {
     "workspace": {
         "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
@@ -944,6 +992,7 @@ DEFAULT_USER_PERMISSIONS = {
     "features": {
         "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
         "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
+        "code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
     },
 }
@@ -1052,7 +1101,7 @@ try:
     banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
     banners = [BannerModel(**banner) for banner in banners]
 except Exception as e:
-    print(f"Error loading WEBUI_BANNERS: {e}")
+    log.exception(f"Error loading WEBUI_BANNERS: {e}")
     banners = []

 WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
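For the `WEBUI_BANNERS` block above: the env var holds a JSON array, and each element is splatted into `BannerModel(**banner)`. The model's fields are not shown in this diff, so the field names in this sketch are hypothetical:

```python
import json
import os

# Hypothetical banner fields for illustration only; the real schema is
# whatever BannerModel (not part of this diff) defines.
os.environ["WEBUI_BANNERS"] = json.dumps(
    [{"id": "maintenance", "type": "info", "content": "Upgrading at 02:00 UTC"}]
)

banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
print(banners[0]["content"])  # each dict would become BannerModel(**banner)
```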
@@ -1094,21 +1143,27 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
 )

-DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
-
-Examples of titles:
-📉 Stock Market Trends
-🍪 Perfect Chocolate Chip Recipe
-Evolution of Music Streaming
-Remote Work Productivity Tips
-Artificial Intelligence in Healthcare
-🎮 Video Game Development Insights
-
+DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """### Task:
+Generate a concise, 3-5 word title with an emoji summarizing the chat history.
+### Guidelines:
+- The title should clearly represent the main theme or subject of the conversation.
+- Use emojis that enhance understanding of the topic, but avoid quotation marks or special formatting.
+- Write the title in the chat's primary language; default to English if multilingual.
+- Prioritize accuracy over excessive creativity; keep it clear and simple.
+### Output:
+JSON format: { "title": "your concise title here" }
+### Examples:
+- { "title": "📉 Stock Market Trends" },
+- { "title": "🍪 Perfect Chocolate Chip Recipe" },
+- { "title": "Evolution of Music Streaming" },
+- { "title": "Remote Work Productivity Tips" },
+- { "title": "Artificial Intelligence in Healthcare" },
+- { "title": "🎮 Video Game Development Insights" }
+### Chat History:
 <chat_history>
 {{MESSAGES:END:2}}
 </chat_history>"""

 TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     "TAGS_GENERATION_PROMPT_TEMPLATE",
     "task.tags.prompt_template",
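The rewritten title template asks the model for `{ "title": ... }` rather than bare text, which makes the reply machine-parseable. A hedged sketch of the consuming side; the project's actual extraction logic lives elsewhere and is not shown in this diff:

```python
import json

def extract_title(model_output: str, fallback: str = "New Chat") -> str:
    # Assumes the model followed the template and emitted a JSON object;
    # falls back gracefully when it did not.
    try:
        data = json.loads(model_output.strip())
    except json.JSONDecodeError:
        return fallback
    if isinstance(data, dict):
        return data.get("title", fallback)
    return fallback

print(extract_title('{ "title": "📉 Stock Market Trends" }'))  # 📉 Stock Market Trends
print(extract_title("not json"))                               # New Chat
```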
@@ -1165,6 +1220,12 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
     os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
 )

+ENABLE_TITLE_GENERATION = PersistentConfig(
+    "ENABLE_TITLE_GENERATION",
+    "task.title.enable",
+    os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true",
+)
+
 ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
     "ENABLE_SEARCH_QUERY_GENERATION",
@@ -1277,7 +1338,28 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 )


-DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
+DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}
+
+Your task is to choose and return the correct tool(s) from the list of available tools based on the query. Follow these guidelines:
+
+- Return only the JSON object, without any additional text or explanation.
+
+- If no tools match the query, return an empty array:
+   {
+     "tool_calls": []
+   }
+
+- If one or more tools match the query, construct a JSON response containing a "tool_calls" array with objects that include:
+   - "name": The tool's name.
+   - "parameters": A dictionary of required parameters and their corresponding values.
+
+The format for the JSON response is strictly:
+{
+  "tool_calls": [
+    {"name": "toolName1", "parameters": {"key1": "value1"}},
+    {"name": "toolName2", "parameters": {"key2": "value2"}}
+  ]
+}"""

 DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
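The rewritten template above pins the model to a strict `{"tool_calls": [...]}` schema, replacing the old empty-string convention. A hedged sketch of a consumer for that schema; the dispatch table and tool names are illustrative only, not Open WebUI's actual dispatcher:

```python
import json

def run_tool_calls(model_output: str, tools: dict) -> list:
    """Parse the strict {"tool_calls": [...]} schema and dispatch each call."""
    try:
        data = json.loads(model_output)
    except json.JSONDecodeError:
        return []  # model ignored the schema; treat as "no tools matched"
    results = []
    for call in data.get("tool_calls", []):
        func = tools.get(call.get("name"))
        if func is not None:
            results.append(func(**call.get("parameters", {})))
    return results

# Hypothetical tool registry for illustration.
tools = {"add": lambda a, b: a + b}
reply = '{"tool_calls": [{"name": "add", "parameters": {"a": 2, "b": 3}}]}'
print(run_tool_calls(reply, tools))  # [5]
print(run_tool_calls('{"tool_calls": []}', tools))  # []
```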
@ -1290,6 +1372,131 @@ Your task is to synthesize these responses into a single, high-quality response.
|
||||||
|
|
||||||
Responses from models: {{responses}}"""
|
Responses from models: {{responses}}"""
|
||||||
|
|
||||||
|
|
||||||
|
+####################################
+# Code Interpreter
+####################################
+
+CODE_EXECUTION_ENGINE = PersistentConfig(
+    "CODE_EXECUTION_ENGINE",
+    "code_execution.engine",
+    os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"),
+)
+
+CODE_EXECUTION_JUPYTER_URL = PersistentConfig(
+    "CODE_EXECUTION_JUPYTER_URL",
+    "code_execution.jupyter.url",
+    os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
+)
+
+CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig(
+    "CODE_EXECUTION_JUPYTER_AUTH",
+    "code_execution.jupyter.auth",
+    os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
+)
+
+CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig(
+    "CODE_EXECUTION_JUPYTER_AUTH_TOKEN",
+    "code_execution.jupyter.auth_token",
+    os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
+)
+
+
+CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig(
+    "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD",
+    "code_execution.jupyter.auth_password",
+    os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
+)
+
+CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig(
+    "CODE_EXECUTION_JUPYTER_TIMEOUT",
+    "code_execution.jupyter.timeout",
+    int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")),
+)
+
+ENABLE_CODE_INTERPRETER = PersistentConfig(
+    "ENABLE_CODE_INTERPRETER",
+    "code_interpreter.enable",
+    os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
+)
+
+CODE_INTERPRETER_ENGINE = PersistentConfig(
+    "CODE_INTERPRETER_ENGINE",
+    "code_interpreter.engine",
+    os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
+)
+
+CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
+    "CODE_INTERPRETER_PROMPT_TEMPLATE",
+    "code_interpreter.prompt_template",
+    os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
+    "CODE_INTERPRETER_JUPYTER_URL",
+    "code_interpreter.jupyter.url",
+    os.environ.get(
+        "CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "")
+    ),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
+    "CODE_INTERPRETER_JUPYTER_AUTH",
+    "code_interpreter.jupyter.auth",
+    os.environ.get(
+        "CODE_INTERPRETER_JUPYTER_AUTH",
+        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
+    ),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
+    "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
+    "code_interpreter.jupyter.auth_token",
+    os.environ.get(
+        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
+        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
+    ),
+)
+
+
+CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
+    "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
+    "code_interpreter.jupyter.auth_password",
+    os.environ.get(
+        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
+        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
+    ),
+)
+
+CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
+    "CODE_INTERPRETER_JUPYTER_TIMEOUT",
+    "code_interpreter.jupyter.timeout",
+    int(
+        os.environ.get(
+            "CODE_INTERPRETER_JUPYTER_TIMEOUT",
+            os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
+        )
+    ),
+)
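Worth noting in passing: every `CODE_INTERPRETER_JUPYTER_*` default above chains to its `CODE_EXECUTION_JUPYTER_*` counterpart through a nested `os.environ.get`, so a deployment that only configures the shared execution settings still gets a working interpreter. A minimal, self-contained sketch of that precedence (the values here are illustrative):

```python
import os

# Simulate a deployment that only sets the shared execution-level URL.
os.environ["CODE_EXECUTION_JUPYTER_URL"] = "http://jupyter:8888"

# Specific key wins if set; otherwise the shared key; otherwise the default.
url = os.environ.get(
    "CODE_INTERPRETER_JUPYTER_URL",
    os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
)
print(url)  # -> http://jupyter:8888
```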
+
+DEFAULT_CODE_INTERPRETER_PROMPT = """
+#### Tools Available
+
+1. **Code Interpreter**: `<code_interpreter type="code" lang="python"></code_interpreter>`
+   - You have access to a Python shell that runs directly in the user's browser, enabling fast execution of code for analysis, calculations, or problem-solving. Use it in this response.
+   - The Python code you write can incorporate a wide array of libraries, handle data manipulation or visualization, perform API calls for web-related tasks, or tackle virtually any computational challenge. Use this flexibility to **think outside the box, craft elegant solutions, and harness Python's full potential**.
+   - To use it, **you must enclose your code within `<code_interpreter type="code" lang="python">` XML tags** and stop right away. If you don't, the code won't execute. Do NOT use triple backticks.
+   - When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
+   - After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
+   - If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
+   - **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS repeat it word for word and explicitly display it as part of the response so the user can access it easily; do NOT change the link.**
+   - All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
+
+Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
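The template above tells the model to wrap code in `<code_interpreter>` tags rather than triple backticks, which makes the executable block unambiguous to locate in the reply. As a rough illustration of why that format is convenient, a hypothetical extractor (this regex and helper are not Open WebUI's actual parser):

```python
import re

def extract_code(reply: str) -> str | None:
    # Hypothetical helper: pull the Python source out of a reply that
    # follows the <code_interpreter ...> convention from the prompt above.
    match = re.search(
        r'<code_interpreter type="code" lang="python">(.*?)</code_interpreter>',
        reply,
        re.DOTALL,
    )
    return match.group(1).strip() if match else None

reply = '<code_interpreter type="code" lang="python">print(2 + 2)</code_interpreter>'
print(extract_code(reply))  # -> print(2 + 2)
```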

####################################
# Vector Database
####################################
@@ -1298,27 +1505,34 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")

# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
-CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
-CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
-CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
-CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
-CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
-# Comma-separated list of header=value pairs
-CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
-if CHROMA_HTTP_HEADERS:
-    CHROMA_HTTP_HEADERS = dict(
-        [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
-    )
-else:
-    CHROMA_HTTP_HEADERS = None
-CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
+
+if VECTOR_DB == "chroma":
+    import chromadb
+
+    CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
+    CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
+    CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
+    CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
+    CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
+    CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get(
+        "CHROMA_CLIENT_AUTH_CREDENTIALS", ""
+    )
+    # Comma-separated list of header=value pairs
+    CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
+    if CHROMA_HTTP_HEADERS:
+        CHROMA_HTTP_HEADERS = dict(
+            [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
+        )
+    else:
+        CHROMA_HTTP_HEADERS = None
+    CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
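One caveat with the header parsing above: a bare `pair.split("=")` yields more than two fields when a value itself contains `=` (common with base64 tokens), and `dict()` then raises `ValueError`. A sketch of the failure and a more tolerant variant (not the shipped code):

```python
raw = "Authorization=Bearer abc==,X-Tenant=acme"

try:
    dict(pair.split("=") for pair in raw.split(","))
except ValueError as e:
    print(e)  # the "==" in the token yields a 4-element sequence

# split("=", 1) keeps everything after the first "=" as the value.
headers = dict(pair.split("=", 1) for pair in raw.split(","))
print(headers)  # -> {'Authorization': 'Bearer abc==', 'X-Tenant': 'acme'}
```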

# Milvus

MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
MILVUS_DB = os.environ.get("MILVUS_DB", "default")
+MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)

# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
@@ -1331,6 +1545,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
+
+# ElasticSearch
+ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
+ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
+ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
+ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
+ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
+ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
+SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)

# Pgvector
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
@@ -1365,6 +1588,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig(
    os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)
+
+ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
+    "ENABLE_ONEDRIVE_INTEGRATION",
+    "onedrive.enable",
+    os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
+)
+
+ONEDRIVE_CLIENT_ID = PersistentConfig(
+    "ONEDRIVE_CLIENT_ID",
+    "onedrive.client_id",
+    os.environ.get("ONEDRIVE_CLIENT_ID", ""),
+)

# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
@@ -1384,6 +1619,26 @@ DOCLING_SERVER_URL = PersistentConfig(
    os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
)
+
+DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
+    "DOCUMENT_INTELLIGENCE_ENDPOINT",
+    "rag.document_intelligence_endpoint",
+    os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
+)
+
+DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
+    "DOCUMENT_INTELLIGENCE_KEY",
+    "rag.document_intelligence_key",
+    os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
+)
+
+
+BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
+    "BYPASS_EMBEDDING_AND_RETRIEVAL",
+    "rag.bypass_embedding_and_retrieval",
+    os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
+)

RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
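All of the boolean settings in this file use the same string-comparison idiom, so only the exact string `true` (any case) enables a flag. Distilled into a standalone helper (`env_flag` is hypothetical, not part of the codebase):

```python
import os

def env_flag(name: str, default: str = "False") -> bool:
    # Only "true"/"True"/"TRUE" enable the flag; anything else disables it.
    return os.environ.get(name, default).lower() == "true"

os.environ["BYPASS_EMBEDDING_AND_RETRIEVAL"] = "True"
print(env_flag("BYPASS_EMBEDDING_AND_RETRIEVAL"))  # -> True
print(env_flag("SOME_UNSET_FLAG"))                 # -> False
```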
@@ -1399,6 +1654,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
+
+RAG_FULL_CONTEXT = PersistentConfig(
+    "RAG_FULL_CONTEXT",
+    "rag.full_context",
+    os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
+)

RAG_FILE_MAX_COUNT = PersistentConfig(
    "RAG_FILE_MAX_COUNT",
    "rag.file.max_count",
@@ -1513,7 +1774,7 @@ Respond to the user query using the provided context, incorporating inline citations
- Respond in the same language as the user's query.
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
-- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
+- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `<source_id>` tag is explicitly provided in the context.**
- Do not cite if the <source_id> tag is not provided in the context.
- Do not use XML tags in your response.
- Ensure citations are concise and directly related to the information provided.
@@ -1594,11 +1855,17 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
+
+BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
+    "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
+    "rag.web.search.bypass_embedding_and_retrieval",
+    os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
+)

# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
-    "rag.rag.web.search.domain.filter_list",
+    "rag.web.search.domain.filter_list",
    [
        # "wikipedia.com",
        # "wikimedia.org",
@@ -1643,6 +1910,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
    os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
+
+BOCHA_SEARCH_API_KEY = PersistentConfig(
+    "BOCHA_SEARCH_API_KEY",
+    "rag.web.search.bocha_search_api_key",
+    os.getenv("BOCHA_SEARCH_API_KEY", ""),
+)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
@@ -1691,6 +1964,18 @@ SEARCHAPI_ENGINE = PersistentConfig(
    os.getenv("SEARCHAPI_ENGINE", ""),
)
+
+SERPAPI_API_KEY = PersistentConfig(
+    "SERPAPI_API_KEY",
+    "rag.web.search.serpapi_api_key",
+    os.getenv("SERPAPI_API_KEY", ""),
+)
+
+SERPAPI_ENGINE = PersistentConfig(
+    "SERPAPI_ENGINE",
+    "rag.web.search.serpapi_engine",
+    os.getenv("SERPAPI_ENGINE", ""),
+)

BING_SEARCH_V7_ENDPOINT = PersistentConfig(
    "BING_SEARCH_V7_ENDPOINT",
    "rag.web.search.bing_search_v7_endpoint",
@@ -1705,6 +1990,17 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
    os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)
+
+EXA_API_KEY = PersistentConfig(
+    "EXA_API_KEY",
+    "rag.web.search.exa_api_key",
+    os.getenv("EXA_API_KEY", ""),
+)
+
+PERPLEXITY_API_KEY = PersistentConfig(
+    "PERPLEXITY_API_KEY",
+    "rag.web.search.perplexity_api_key",
+    os.getenv("PERPLEXITY_API_KEY", ""),
+)

RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
@@ -1718,6 +2014,35 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
+
+RAG_WEB_LOADER_ENGINE = PersistentConfig(
+    "RAG_WEB_LOADER_ENGINE",
+    "rag.web.loader.engine",
+    os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
+)
+
+RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
+    "RAG_WEB_SEARCH_TRUST_ENV",
+    "rag.web.search.trust_env",
+    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
+)
+
+PLAYWRIGHT_WS_URI = PersistentConfig(
+    "PLAYWRIGHT_WS_URI",
+    "rag.web.loader.engine.playwright.ws.uri",
+    os.environ.get("PLAYWRIGHT_WS_URI", None),
+)
+
+FIRECRAWL_API_KEY = PersistentConfig(
+    "FIRECRAWL_API_KEY",
+    "firecrawl.api_key",
+    os.environ.get("FIRECRAWL_API_KEY", ""),
+)
+
+FIRECRAWL_API_BASE_URL = PersistentConfig(
+    "FIRECRAWL_API_BASE_URL",
+    "firecrawl.api_url",
+    os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
+)

####################################
# Images
@@ -1929,6 +2254,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
    os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
+
+IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
+    "IMAGES_GEMINI_API_BASE_URL",
+    "image_generation.gemini.api_base_url",
+    os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
+)
+IMAGES_GEMINI_API_KEY = PersistentConfig(
+    "IMAGES_GEMINI_API_KEY",
+    "image_generation.gemini.api_key",
+    os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
+)

IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
@@ -1960,6 +2296,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
    and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
+
+# Add Deepgram configuration
+DEEPGRAM_API_KEY = PersistentConfig(
+    "DEEPGRAM_API_KEY",
+    "audio.stt.deepgram.api_key",
+    os.getenv("DEEPGRAM_API_KEY", ""),
+)

AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
@@ -2099,7 +2441,7 @@ LDAP_SEARCH_BASE = PersistentConfig(
LDAP_SEARCH_FILTERS = PersistentConfig(
    "LDAP_SEARCH_FILTER",
    "ldap.server.search_filter",
-    os.environ.get("LDAP_SEARCH_FILTER", ""),
+    os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")),
)

LDAP_USE_TLS = PersistentConfig(
@@ -57,7 +57,7 @@ class ERROR_MESSAGES(str, Enum):
    )

    FILE_NOT_SENT = "FILE_NOT_SENT"
-    FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format (e.g., JPG, PNG, PDF, TXT) and try again."
+    FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."

    NOT_FOUND = "We could not find what you're looking for :/"
    USER_NOT_FOUND = "We could not find what you're looking for :/"
@@ -65,10 +65,8 @@ except Exception:
# LOGGING
####################################

-log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
-
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
-if GLOBAL_LOG_LEVEL in log_levels:
+if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
    GLOBAL_LOG_LEVEL = "INFO"
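`logging.getLevelNamesMapping()` (Python 3.11+) returns the same name-to-level table that the removed `log_levels` list hard-coded, with the bonus that standard aliases such as `WARN` and `FATAL` are now accepted too. A quick runnable check:

```python
import logging

mapping = logging.getLevelNamesMapping()
print(sorted(mapping))  # CRITICAL, DEBUG, ERROR, FATAL, INFO, NOTSET, WARN, WARNING

print("DEBUG" in mapping)    # -> True
print("VERBOSE" in mapping)  # -> False
```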
@@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

if "cuda_error" in locals():
    log.exception(cuda_error)
+    del cuda_error

log_sources = [
    "AUDIO",
@@ -92,6 +91,7 @@ log_sources = [
    "RAG",
    "WEBHOOK",
    "SOCKET",
+    "OAUTH",
]

SRC_LOG_LEVELS = {}
@@ -99,7 +99,7 @@ SRC_LOG_LEVELS = {}
for source in log_sources:
    log_env_var = source + "_LOG_LEVEL"
    SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
-    if SRC_LOG_LEVELS[source] not in log_levels:
+    if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
        SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
@@ -112,6 +112,7 @@ if WEBUI_NAME != "Open WebUI":

WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"

+TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")

####################################
# ENV (dev,test,prod)
@@ -356,14 +357,22 @@ WEBUI_SECRET_KEY = os.environ.get(
    ),  # DEPRECATED: remove at next major version
)

-WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
-    "WEBUI_SESSION_COOKIE_SAME_SITE",
-    os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
-)
+WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")

-WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
-    "WEBUI_SESSION_COOKIE_SECURE",
-    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
-)
+WEBUI_SESSION_COOKIE_SECURE = (
+    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
+)
+
+WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get(
+    "WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE
+)
+
+WEBUI_AUTH_COOKIE_SECURE = (
+    os.environ.get(
+        "WEBUI_AUTH_COOKIE_SECURE",
+        os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"),
+    ).lower()
+    == "true"
+)

if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
@@ -376,6 +385,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")

WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
+WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)

AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@@ -387,19 +397,20 @@ else:
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300

-AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
-    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
+AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
+    "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
+    os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
)

-if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
-    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
+if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
+    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
else:
    try:
-        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
-            AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
-        )
+        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    except Exception:
-        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
+        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5


####################################
# OFFLINE_MODE
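The renamed `AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST` keeps the deprecated `AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST` working as a fallback, treats an empty string as "no timeout", and falls back to 5 seconds on junk input. The same parse-or-default shape in isolation (the helper name is illustrative):

```python
import os

def env_int_or_none(name: str, deprecated: str, default: int) -> int | None:
    # New name wins, then the deprecated one; "" means "no timeout" (None).
    raw = os.environ.get(name, os.environ.get(deprecated, ""))
    if raw == "":
        return None
    try:
        return int(raw)
    except Exception:
        return default

os.environ["AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST"] = "15"
print(
    env_int_or_none(
        "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
        "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST",
        5,
    )
)  # -> 15
```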
@@ -409,3 +420,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"

if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"
+
+####################################
+# AUDIT LOGGING
+####################################
+ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
+# Where to store log file
+AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
+# Maximum size of a file before rotating into a new log file
+AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+# METADATA | REQUEST | REQUEST_RESPONSE
+AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
+try:
+    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
+except ValueError:
+    MAX_BODY_LOG_SIZE = 2048
+
+# Comma separated list for urls to exclude from audit
+AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
+    ","
+)
+AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
+AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
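The two trailing comprehensions normalize the exclusion list so that `/chats`, ` chats ` and `chats` all compare equal later. What they produce, runnable as-is:

```python
# Same normalization as above: split, trim whitespace, drop leading slashes.
raw = "/chats, /chat ,folders"
paths = [p.strip() for p in raw.split(",")]
paths = [p.lstrip("/") for p in paths]
print(paths)  # -> ['chats', 'chat', 'folders']
```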
@@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
+import asyncio

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
@@ -76,10 +77,12 @@ async def get_function_models(request):
    if hasattr(function_module, "pipes"):
        sub_pipes = []

-        # Check if pipes is a function or a list
+        # Handle pipes being a list, sync function, or async function
        try:
            if callable(function_module.pipes):
-                sub_pipes = function_module.pipes()
+                if asyncio.iscoroutinefunction(function_module.pipes):
+                    sub_pipes = await function_module.pipes()
+                else:
+                    sub_pipes = function_module.pipes()
            else:
                sub_pipes = function_module.pipes
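The branch added above lets a plugin expose `pipes` as a plain list, a sync function, or a coroutine function. The dispatch pattern on its own, as a runnable sketch:

```python
import asyncio

def sync_pipes():
    return ["pipe-a"]

async def async_pipes():
    return ["pipe-b"]

async def resolve(pipes):
    # Accept a plain list, a sync callable, or an async callable.
    if callable(pipes):
        if asyncio.iscoroutinefunction(pipes):
            return await pipes()
        return pipes()
    return pipes

async def main():
    for pipes in (["pipe-0"], sync_pipes, async_pipes):
        print(await resolve(pipes))

asyncio.run(main())  # -> ['pipe-0'], ['pipe-a'], ['pipe-b']
```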
@@ -250,7 +253,7 @@ async def generate_function_chat_completion(

    params = model_info.params.model_dump()
    form_data = apply_model_params_to_body_openai(params, form_data)
-    form_data = apply_model_system_prompt_to_body(params, form_data, user)
+    form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)
@@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse


+from open_webui.utils import logger
+from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
+from open_webui.utils.logger import start_logger
from open_webui.socket.main import (
    app as socket_app,
    periodic_usage_pool_cleanup,
@@ -88,15 +91,34 @@ from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users

from open_webui.config import (
+    LICENSE_KEY,
    # Ollama
    ENABLE_OLLAMA_API,
    OLLAMA_BASE_URLS,
    OLLAMA_API_CONFIGS,
    # OpenAI
    ENABLE_OPENAI_API,
+    ONEDRIVE_CLIENT_ID,
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    OPENAI_API_CONFIGS,
+    # Direct Connections
+    ENABLE_DIRECT_CONNECTIONS,
+    # Code Execution
+    CODE_EXECUTION_ENGINE,
+    CODE_EXECUTION_JUPYTER_URL,
+    CODE_EXECUTION_JUPYTER_AUTH,
+    CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
+    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
+    CODE_EXECUTION_JUPYTER_TIMEOUT,
+    ENABLE_CODE_INTERPRETER,
+    CODE_INTERPRETER_ENGINE,
+    CODE_INTERPRETER_PROMPT_TEMPLATE,
+    CODE_INTERPRETER_JUPYTER_URL,
+    CODE_INTERPRETER_JUPYTER_AUTH,
+    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+    CODE_INTERPRETER_JUPYTER_TIMEOUT,
    # Image
    AUTOMATIC1111_API_AUTH,
    AUTOMATIC1111_BASE_URL,
@@ -115,6 +137,8 @@ from open_webui.config import (
    IMAGE_STEPS,
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
+    IMAGES_GEMINI_API_BASE_URL,
+    IMAGES_GEMINI_API_KEY,
    # Audio
    AUDIO_STT_ENGINE,
    AUDIO_STT_MODEL,
@@ -129,12 +153,19 @@ from open_webui.config import (
    AUDIO_TTS_VOICE,
    AUDIO_TTS_AZURE_SPEECH_REGION,
    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+    PLAYWRIGHT_WS_URI,
+    FIRECRAWL_API_BASE_URL,
+    FIRECRAWL_API_KEY,
+    RAG_WEB_LOADER_ENGINE,
    WHISPER_MODEL,
+    DEEPGRAM_API_KEY,
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    # Retrieval
    RAG_TEMPLATE,
    DEFAULT_RAG_TEMPLATE,
+    RAG_FULL_CONTEXT,
+    BYPASS_EMBEDDING_AND_RETRIEVAL,
    RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@@ -155,6 +186,8 @@ from open_webui.config import (
    CONTENT_EXTRACTION_ENGINE,
    TIKA_SERVER_URL,
    DOCLING_SERVER_URL,
+    DOCUMENT_INTELLIGENCE_ENDPOINT,
+    DOCUMENT_INTELLIGENCE_KEY,
    RAG_TOP_K,
    RAG_TEXT_SPLITTER,
    TIKTOKEN_ENCODING_NAME,
@@ -163,12 +196,16 @@ from open_webui.config import (
    YOUTUBE_LOADER_PROXY_URL,
    # Retrieval (Web Search)
    RAG_WEB_SEARCH_ENGINE,
+    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
    RAG_WEB_SEARCH_RESULT_COUNT,
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_TRUST_ENV,
    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
    JINA_API_KEY,
    SEARCHAPI_API_KEY,
    SEARCHAPI_ENGINE,
+    SERPAPI_API_KEY,
+    SERPAPI_ENGINE,
    SEARXNG_QUERY_URL,
    SERPER_API_KEY,
    SERPLY_API_KEY,
@@ -178,17 +215,22 @@ from open_webui.config import (
    BING_SEARCH_V7_ENDPOINT,
    BING_SEARCH_V7_SUBSCRIPTION_KEY,
    BRAVE_SEARCH_API_KEY,
+    EXA_API_KEY,
+    PERPLEXITY_API_KEY,
    KAGI_SEARCH_API_KEY,
    MOJEEK_SEARCH_API_KEY,
+    BOCHA_SEARCH_API_KEY,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
    GOOGLE_DRIVE_CLIENT_ID,
    GOOGLE_DRIVE_API_KEY,
+    ONEDRIVE_CLIENT_ID,
    ENABLE_RAG_HYBRID_SEARCH,
    ENABLE_RAG_LOCAL_WEB_FETCH,
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    ENABLE_RAG_WEB_SEARCH,
    ENABLE_GOOGLE_DRIVE_INTEGRATION,
+    ENABLE_ONEDRIVE_INTEGRATION,
    UPLOAD_DIR,
    # WebUI
    WEBUI_AUTH,
@@ -252,6 +294,7 @@ from open_webui.config import (
    TASK_MODEL,
    TASK_MODEL_EXTERNAL,
    ENABLE_TAGS_GENERATION,
+    ENABLE_TITLE_GENERATION,
    ENABLE_SEARCH_QUERY_GENERATION,
    ENABLE_RETRIEVAL_QUERY_GENERATION,
    ENABLE_AUTOCOMPLETE_GENERATION,
@@ -266,8 +309,11 @@ from open_webui.config import (
    reset_config,
)
from open_webui.env import (
+    AUDIT_EXCLUDED_PATHS,
+    AUDIT_LOG_LEVEL,
    CHANGELOG,
    GLOBAL_LOG_LEVEL,
+    MAX_BODY_LOG_SIZE,
    SAFE_MODE,
    SRC_LOG_LEVELS,
    VERSION,
@@ -298,15 +344,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access

from open_webui.utils.auth import (
+    get_license_data,
    decode_token,
    get_admin_user,
    get_verified_user,
)
-from open_webui.utils.oauth import oauth_manager
+from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware

from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py


if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()
@@ -322,6 +370,10 @@ class SPAStaticFiles(StaticFiles):
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            if ex.status_code == 404:
+                if path.endswith(".js"):
+                    # Return 404 for javascript files
+                    raise ex
+                else:
                    return await super().get_response("index.html", scope)
            else:
                raise ex
@@ -329,12 +381,12 @@ class SPAStaticFiles(StaticFiles):

print(
    rf"""
-  ___                    __        __   _     _   _ ___
- / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
-| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
-| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
- \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
-      |_|
+ ██████╗ ██████╗ ███████╗███╗   ██╗    ██╗    ██╗███████╗██████╗ ██╗   ██╗██╗
+██╔═══██╗██╔══██╗██╔════╝████╗  ██║    ██║    ██║██╔════╝██╔══██╗██║   ██║██║
+██║   ██║██████╔╝█████╗  ██╔██╗ ██║    ██║ █╗ ██║█████╗  ██████╔╝██║   ██║██║
+██║   ██║██╔═══╝ ██╔══╝  ██║╚██╗██║    ██║███╗██║██╔══╝  ██╔══██╗██║   ██║██║
+╚██████╔╝██║     ███████╗██║ ╚████║    ╚███╔███╔╝███████╗██████╔╝╚██████╔╝██║
+ ╚═════╝ ╚═╝     ╚══════╝╚═╝  ╚═══╝     ╚══╝╚══╝ ╚══════╝╚═════╝  ╚═════╝ ╚═╝


v{VERSION} - building the best open-source AI user interface.
@@ -346,9 +398,13 @@ https://github.com/open-webui/open-webui

@asynccontextmanager
async def lifespan(app: FastAPI):
+    start_logger()
    if RESET_CONFIG_ON_START:
        reset_config()

+    if LICENSE_KEY:
+        get_license_data(app, LICENSE_KEY)
+
    asyncio.create_task(periodic_usage_pool_cleanup())
    yield
@@ -360,8 +416,12 @@ app = FastAPI(
    lifespan=lifespan,
)

+oauth_manager = OAuthManager(app)
+
app.state.config = AppConfig()

+app.state.WEBUI_NAME = WEBUI_NAME
+app.state.LICENSE_METADATA = None

########################################
#
@@ -389,6 +449,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS

app.state.OPENAI_MODELS = {}

+########################################
+#
+# DIRECT CONNECTIONS
+#
+########################################
+
+app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+
########################################
#
# WEBUI
@@ -455,10 +523,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER

-app.state.USER_COUNT = None
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

########################################
#
# RETRIEVAL
@@ -471,6 +539,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

+
+app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
+app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
@@ -479,6 +550,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
+app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
+app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY

app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
@@ -506,15 +579,20 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL

app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
+    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+)
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
+app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
+app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -522,12 +600,21 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
+app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY
+app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
+app.state.config.EXA_API_KEY = EXA_API_KEY
+app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY

app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
+app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
+app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
+app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
+app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY

app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
@@ -569,6 +656,34 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)

+########################################
+#
+# CODE EXECUTION
+#
+########################################
+
+app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
+app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
+app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
+app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
+    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+)
+app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
+
+app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
+app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
+app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
+
+app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+)
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+)
+app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
+
########################################
#
@@ -583,6 +698,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

+app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
+app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
+
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL

app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
@@ -611,6 +729,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.WHISPER_MODEL = WHISPER_MODEL
+app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@@ -645,6 +764,7 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
+app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION


app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -753,6 +873,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
+
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -781,6 +902,19 @@ app.include_router(
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])


+try:
+    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
+except ValueError as e:
+    logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
+    audit_level = AuditLevel.NONE
+
+if audit_level != AuditLevel.NONE:
+    app.add_middleware(
+        AuditLoggingMiddleware,
+        audit_level=audit_level,
+        excluded_paths=AUDIT_EXCLUDED_PATHS,
+        max_body_size=MAX_BODY_LOG_SIZE,
+    )
+
##################################
#
# Chat Endpoints
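The `AuditLevel(AUDIT_LOG_LEVEL)` call is the standard way to validate a string against an `Enum`: an unknown value raises `ValueError`, which the code above downgrades to `NONE`. A standalone sketch with a stand-in enum (the real definitions live in `open_webui.utils.audit` and may differ):

```python
from enum import Enum

class AuditLevel(str, Enum):  # stand-in, mirroring the documented values
    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"

def parse_level(raw: str) -> AuditLevel:
    try:
        return AuditLevel(raw)  # Enum(value) raises ValueError if unknown
    except ValueError:
        return AuditLevel.NONE

print(parse_level("REQUEST_RESPONSE"))  # -> AuditLevel.REQUEST_RESPONSE
print(parse_level("BOGUS"))             # -> AuditLevel.NONE
```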
@@ -813,7 +947,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):

        return filtered_models

-    models = await get_all_models(request)
+    models = await get_all_models(request, user=user)

    # Filter out filter pipelines
    models = [

@@ -842,7 +976,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):

@app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)):
-    models = await get_all_base_models(request)
+    models = await get_all_base_models(request, user=user)
    return {"data": models}
@@ -853,14 +987,19 @@ async def chat_completion(
    user=Depends(get_verified_user),
):
    if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)

+    model_item = form_data.pop("model_item", {})
    tasks = form_data.pop("background_tasks", None)
+
    try:
+        if not model_item.get("direct", False):
            model_id = form_data.get("model", None)
            if model_id not in request.app.state.MODELS:
                raise Exception("Model not found")

            model = request.app.state.MODELS[model_id]
+            model_info = Models.get_model_by_id(model_id)

            # Check if user has access to the model
            if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
@@ -868,6 +1007,12 @@ async def chat_completion(
                check_model_access(user, model)
            except Exception as e:
                raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model

        metadata = {
            "user_id": user.id,
@@ -877,13 +1022,30 @@ async def chat_completion(
            "tool_ids": form_data.get("tool_ids", None),
            "files": form_data.get("files", None),
            "features": form_data.get("features", None),
+            "variables": form_data.get("variables", None),
+            "model": model,
+            "direct": model_item.get("direct", False),
+            **(
+                {"function_calling": "native"}
+                if form_data.get("params", {}).get("function_calling") == "native"
+                or (
+                    model_info
+                    and model_info.params.model_dump().get("function_calling")
+                    == "native"
+                )
+                else {}
+            ),
        }

+        request.state.metadata = metadata
        form_data["metadata"] = metadata

-        form_data, events = await process_chat_payload(
-            request, form_data, metadata, user, model
+        form_data, metadata, events = await process_chat_payload(
+            request, form_data, user, metadata, model
        )
    except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
@ -891,8 +1053,9 @@ async def chat_completion(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await chat_completion_handler(request, form_data, user)
|
response = await chat_completion_handler(request, form_data, user)
|
||||||
|
|
||||||
return await process_chat_response(
|
return await process_chat_response(
|
||||||
request, response, form_data, user, events, metadata, tasks
|
request, response, form_data, user, metadata, model, events, tasks
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -911,6 +1074,12 @@ async def chat_completed(
|
||||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
model_item = form_data.pop("model_item", {})
|
||||||
|
|
||||||
|
if model_item.get("direct", False):
|
||||||
|
request.state.direct = True
|
||||||
|
request.state.model = model_item
|
||||||
|
|
||||||
return await chat_completed_handler(request, form_data, user)
|
return await chat_completed_handler(request, form_data, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -924,6 +1093,12 @@ async def chat_action(
|
||||||
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
model_item = form_data.pop("model_item", {})
|
||||||
|
|
||||||
|
if model_item.get("direct", False):
|
||||||
|
request.state.direct = True
|
||||||
|
request.state.model = model_item
|
||||||
|
|
||||||
return await chat_action_handler(request, action_id, form_data, user)
|
return await chat_action_handler(request, action_id, form_data, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -969,15 +1144,16 @@ async def get_app_config(request: Request):
|
||||||
if data is not None and "id" in data:
|
if data is not None and "id" in data:
|
||||||
user = Users.get_user_by_id(data["id"])
|
user = Users.get_user_by_id(data["id"])
|
||||||
|
|
||||||
onboarding = False
|
|
||||||
if user is None:
|
|
||||||
user_count = Users.get_num_users()
|
user_count = Users.get_num_users()
|
||||||
|
onboarding = False
|
||||||
|
|
||||||
|
if user is None:
|
||||||
onboarding = user_count == 0
|
onboarding = user_count == 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**({"onboarding": True} if onboarding else {}),
|
**({"onboarding": True} if onboarding else {}),
|
||||||
"status": True,
|
"status": True,
|
||||||
"name": WEBUI_NAME,
|
"name": app.state.WEBUI_NAME,
|
||||||
"version": VERSION,
|
"version": VERSION,
|
||||||
"default_locale": str(DEFAULT_LOCALE),
|
"default_locale": str(DEFAULT_LOCALE),
|
||||||
"oauth": {
|
"oauth": {
|
||||||
|
|
@ -996,27 +1172,31 @@ async def get_app_config(request: Request):
|
||||||
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
|
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
|
||||||
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||||
|
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||||
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||||
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
||||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||||
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
|
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
|
||||||
|
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||||
|
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||||
}
|
}
|
||||||
if user is not None
|
if user is not None
|
||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"google_drive": {
|
|
||||||
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
|
||||||
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
|
||||||
},
|
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
"default_models": app.state.config.DEFAULT_MODELS,
|
"default_models": app.state.config.DEFAULT_MODELS,
|
||||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||||
|
"user_count": user_count,
|
||||||
|
"code": {
|
||||||
|
"engine": app.state.config.CODE_EXECUTION_ENGINE,
|
||||||
|
},
|
||||||
"audio": {
|
"audio": {
|
||||||
"tts": {
|
"tts": {
|
||||||
"engine": app.state.config.TTS_ENGINE,
|
"engine": app.state.config.TTS_ENGINE,
|
||||||
|
|
@ -1032,6 +1212,19 @@ async def get_app_config(request: Request):
|
||||||
"max_count": app.state.config.FILE_MAX_COUNT,
|
"max_count": app.state.config.FILE_MAX_COUNT,
|
||||||
},
|
},
|
||||||
"permissions": {**app.state.config.USER_PERMISSIONS},
|
"permissions": {**app.state.config.USER_PERMISSIONS},
|
||||||
|
"google_drive": {
|
||||||
|
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
||||||
|
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
||||||
|
},
|
||||||
|
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
|
||||||
|
"license_metadata": app.state.LICENSE_METADATA,
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"active_entries": app.state.USER_COUNT,
|
||||||
|
}
|
||||||
|
if user.role == "admin"
|
||||||
|
else {}
|
||||||
|
),
|
||||||
}
|
}
|
||||||
if user is not None
|
if user is not None
|
||||||
else {}
|
else {}
|
||||||
|
|
@ -1065,7 +1258,7 @@ async def get_app_version():
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/version/updates")
|
@app.get("/api/version/updates")
|
||||||
async def get_app_latest_release_version():
|
async def get_app_latest_release_version(user=Depends(get_verified_user)):
|
||||||
if OFFLINE_MODE:
|
if OFFLINE_MODE:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Offline mode is enabled, returning current version as latest version"
|
f"Offline mode is enabled, returning current version as latest version"
|
||||||
|
|
@ -1109,7 +1302,7 @@ if len(OAUTH_PROVIDERS) > 0:
|
||||||
|
|
||||||
@app.get("/oauth/{provider}/login")
|
@app.get("/oauth/{provider}/login")
|
||||||
async def oauth_login(provider: str, request: Request):
|
async def oauth_login(provider: str, request: Request):
|
||||||
return await oauth_manager.handle_login(provider, request)
|
return await oauth_manager.handle_login(request, provider)
|
||||||
|
|
||||||
|
|
||||||
# OAuth login logic is as follows:
|
# OAuth login logic is as follows:
|
||||||
|
|
@ -1120,14 +1313,14 @@ async def oauth_login(provider: str, request: Request):
|
||||||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||||
@app.get("/oauth/{provider}/callback")
|
@app.get("/oauth/{provider}/callback")
|
||||||
async def oauth_callback(provider: str, request: Request, response: Response):
|
async def oauth_callback(provider: str, request: Request, response: Response):
|
||||||
return await oauth_manager.handle_callback(provider, request, response)
|
return await oauth_manager.handle_callback(request, provider, response)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/manifest.json")
|
@app.get("/manifest.json")
|
||||||
async def get_manifest_json():
|
async def get_manifest_json():
|
||||||
return {
|
return {
|
||||||
"name": WEBUI_NAME,
|
"name": app.state.WEBUI_NAME,
|
||||||
"short_name": WEBUI_NAME,
|
"short_name": app.state.WEBUI_NAME,
|
||||||
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
|
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
|
||||||
"start_url": "/",
|
"start_url": "/",
|
||||||
"display": "standalone",
|
"display": "standalone",
|
||||||
|
|
@ -1154,8 +1347,8 @@ async def get_manifest_json():
|
||||||
async def get_opensearch_xml():
|
async def get_opensearch_xml():
|
||||||
xml_content = rf"""
|
xml_content = rf"""
|
||||||
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
|
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
|
||||||
<ShortName>{WEBUI_NAME}</ShortName>
|
<ShortName>{app.state.WEBUI_NAME}</ShortName>
|
||||||
<Description>Search {WEBUI_NAME}</Description>
|
<Description>Search {app.state.WEBUI_NAME}</Description>
|
||||||
<InputEncoding>UTF-8</InputEncoding>
|
<InputEncoding>UTF-8</InputEncoding>
|
||||||
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
|
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
|
||||||
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
|
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
|
||||||
|
|
|
||||||
|
|
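For context on the `model_item` changes above: direct-connection models are supplied by the client instead of being looked up in `request.app.state.MODELS`. A minimal, self-contained sketch of that dispatch; plain dicts stand in for the FastAPI app state, and `resolve_model` is an illustrative helper, not a function in the diff:

# Hedged sketch of the direct-vs-registry dispatch added to chat_completion.
def resolve_model(form_data: dict, registry: dict):
    model_item = form_data.pop("model_item", {})

    if not model_item.get("direct", False):
        # Registry-backed model: must exist in app state.
        model_id = form_data.get("model")
        if model_id not in registry:
            raise Exception("Model not found")
        return registry[model_id], False

    # Direct connection: the client supplied the model definition itself.
    return model_item, True

registry = {"llama3": {"id": "llama3", "name": "Llama 3"}}
model, direct = resolve_model({"model": "llama3"}, registry)
assert not direct and model["id"] == "llama3"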
backend/open_webui/models/chats.py

@@ -1,3 +1,4 @@
+import logging
 import json
 import time
 import uuid

@@ -5,7 +6,7 @@ from typing import Optional
 from open_webui.internal.db import Base, get_db
 from open_webui.models.tags import TagModel, Tag, Tags
+from open_webui.env import SRC_LOG_LEVELS

 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON

@@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
 # Chat DB Schema
 ####################

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+

 class Chat(Base):
     __tablename__ = "chat"

@@ -470,7 +474,7 @@ class ChatTable:
         try:
             with get_db() as db:
                 # it is possible that the shared link was deleted. hence,
-                # we check if the chat is still shared by checkng if a chat with the share_id exists
+                # we check if the chat is still shared by checking if a chat with the share_id exists
                 chat = db.query(Chat).filter_by(share_id=id).first()

                 if chat:

@@ -670,7 +674,7 @@ class ChatTable:
             # Perform pagination at the SQL level
             all_chats = query.offset(skip).limit(limit).all()

-            print(len(all_chats))
+            log.info(f"The number of chats: {len(all_chats)}")

             # Validate and return chats
             return [ChatModel.model_validate(chat) for chat in all_chats]

@@ -731,7 +735,7 @@ class ChatTable:
             query = db.query(Chat).filter_by(user_id=user_id)
             tag_id = tag_name.replace(" ", "_").lower()

-            print(db.bind.dialect.name)
+            log.info(f"DB dialect name: {db.bind.dialect.name}")
             if db.bind.dialect.name == "sqlite":
                 # SQLite JSON1 querying for tags within the meta JSON field
                 query = query.filter(

@@ -752,7 +756,7 @@ class ChatTable:
             )

             all_chats = query.all()
-            print("all_chats", all_chats)
+            log.debug(f"all_chats: {all_chats}")
             return [ChatModel.model_validate(chat) for chat in all_chats]

     def add_chat_tag_by_id_and_user_id_and_tag_name(

@@ -810,7 +814,7 @@ class ChatTable:
             count = query.count()

             # Debugging output for inspection
-            print(f"Count of chats for tag '{tag_name}':", count)
+            log.info(f"Count of chats for tag '{tag_name}': {count}")

             return count
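The tag queries above branch on `db.bind.dialect.name` because only SQLite uses the JSON1 functions. A small sketch of that path, assuming the interpreter's bundled SQLite ships JSON1 (true for any recent CPython build); the real code runs inside `ChatTable` with the app's session:

# Hedged sketch: dialect-dependent JSON querying with an in-memory engine.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    if conn.engine.dialect.name == "sqlite":
        # JSON1 is available: extract a field from a JSON literal.
        row = conn.execute(
            text("""SELECT json_extract('{"tags": ["work"]}', '$.tags[0]')""")
        ).scalar()
        assert row == "work"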
backend/open_webui/models/feedbacks.py

@@ -118,7 +118,7 @@ class FeedbackTable:
             else:
                 return None
         except Exception as e:
-            print(e)
+            log.exception(f"Error creating a new feedback: {e}")
             return None

     def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:

backend/open_webui/models/files.py

@@ -119,7 +119,7 @@ class FilesTable:
             else:
                 return None
         except Exception as e:
-            print(f"Error creating tool: {e}")
+            log.exception(f"Error inserting a new file: {e}")
             return None

     def get_file_by_id(self, id: str) -> Optional[FileModel]:

backend/open_webui/models/folders.py

@@ -82,7 +82,7 @@ class FolderTable:
             else:
                 return None
         except Exception as e:
-            print(e)
+            log.exception(f"Error inserting a new folder: {e}")
             return None

     def get_folder_by_id_and_user_id(

backend/open_webui/models/functions.py

@@ -105,7 +105,7 @@ class FunctionsTable:
             else:
                 return None
         except Exception as e:
-            print(f"Error creating tool: {e}")
+            log.exception(f"Error creating a new function: {e}")
             return None

     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:

@@ -170,7 +170,7 @@ class FunctionsTable:
             function = db.get(Function, id)
             return function.valves if function.valves else {}
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(f"Error getting function valves by id {id}: {e}")
             return None

     def update_function_valves_by_id(

@@ -202,7 +202,9 @@ class FunctionsTable:
             return user_settings["functions"]["valves"].get(id, {})
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(
+                f"Error getting user values by id {id} and user id {user_id}: {e}"
+            )
             return None

     def update_user_valves_by_id_and_user_id(

@@ -225,7 +227,9 @@ class FunctionsTable:
             return user_settings["functions"]["valves"][id]
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(
+                f"Error updating user valves by id {id} and user_id {user_id}: {e}"
+            )
             return None

     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
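The repeated `print(e)` to `log.exception(...)` replacements across these model tables all follow one pattern. A runnable sketch with `SRC_LOG_LEVELS` stubbed; in Open WebUI it comes from `open_webui.env`:

# Sketch of the module-level logger pattern the diff applies package-wide.
import logging

SRC_LOG_LEVELS = {"MODELS": logging.INFO}  # stand-in for the real mapping

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

try:
    raise ValueError("boom")
except Exception as e:
    # log.exception records the message *and* the traceback, unlike print(e).
    log.exception(f"Error creating a new feedback: {e}")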
5  backend/open_webui/models/models.py  Normal file → Executable file

@@ -166,7 +166,7 @@ class ModelsTable:
             else:
                 return None
         except Exception as e:
-            print(e)
+            log.exception(f"Failed to insert a new model: {e}")
             return None

     def get_all_models(self) -> list[ModelModel]:

@@ -246,8 +246,7 @@ class ModelsTable:
                 db.refresh(model)
                 return ModelModel.model_validate(model)
         except Exception as e:
-            print(e)
-
+            log.exception(f"Failed to update the model by id {id}: {e}")
             return None

     def delete_model_by_id(self, id: str) -> bool:
backend/open_webui/models/tags.py

@@ -61,7 +61,7 @@ class TagTable:
             else:
                 return None
         except Exception as e:
-            print(e)
+            log.exception(f"Error inserting a new tag: {e}")
             return None

     def get_tag_by_name_and_user_id(

backend/open_webui/models/tools.py

@@ -131,7 +131,7 @@ class ToolsTable:
             else:
                 return None
         except Exception as e:
-            print(f"Error creating tool: {e}")
+            log.exception(f"Error creating a new tool: {e}")
             return None

     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:

@@ -175,7 +175,7 @@ class ToolsTable:
             tool = db.get(Tool, id)
             return tool.valves if tool.valves else {}
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(f"Error getting tool valves by id {id}: {e}")
             return None

     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:

@@ -204,7 +204,9 @@ class ToolsTable:
             return user_settings["tools"]["valves"].get(id, {})
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(
+                f"Error getting user values by id {id} and user_id {user_id}: {e}"
+            )
             return None

     def update_user_valves_by_id_and_user_id(

@@ -227,7 +229,9 @@ class ToolsTable:
             return user_settings["tools"]["valves"][id]
         except Exception as e:
-            print(f"An error occurred: {e}")
+            log.exception(
+                f"Error updating user valves by id {id} and user_id {user_id}: {e}"
+            )
             return None

     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:

backend/open_webui/models/users.py

@@ -271,6 +271,24 @@ class UsersTable:
         except Exception:
             return None

+    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
+        try:
+            with get_db() as db:
+                user_settings = db.query(User).filter_by(id=id).first().settings
+
+                if user_settings is None:
+                    user_settings = {}
+
+                user_settings.update(updated)
+
+                db.query(User).filter_by(id=id).update({"settings": user_settings})
+                db.commit()
+
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+        except Exception:
+            return None
+
     def delete_user_by_id(self, id: str) -> bool:
         try:
             # Remove User from Groups
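Worth noting about the new `update_user_settings_by_id`: the merge is a shallow `dict.update`, so a nested sub-dict is replaced wholesale rather than deep-merged. Illustrated with plain dicts, no DB session:

# Hedged usage sketch of the merge semantics the method performs.
stored_settings = {"ui": {"theme": "dark"}, "notifications": True}
updated = {"ui": {"theme": "light"}}

stored_settings.update(updated)  # same operation the method performs

# Shallow merge: the whole "ui" sub-dict was replaced, "notifications" kept.
assert stored_settings == {"ui": {"theme": "light"}, "notifications": True}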
backend/open_webui/retrieval/loaders/main.py

@@ -4,6 +4,7 @@ import ftfy
 import sys

 from langchain_community.document_loaders import (
+    AzureAIDocumentIntelligenceLoader,
     BSHTMLLoader,
     CSVLoader,
     Docx2txtLoader,

@@ -76,6 +77,7 @@ known_source_ext = [
     "jsx",
     "hs",
     "lhs",
+    "json",
 ]

@@ -221,12 +223,33 @@ class Loader:
                 file_path=file_path,
                 mime_type=file_content_type,
             )
-        elif self.engine == "docling":
+        elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
             loader = DoclingLoader(
                 url=self.kwargs.get("DOCLING_SERVER_URL"),
                 file_path=file_path,
                 mime_type=file_content_type,
             )
+        elif (
+            self.engine == "document_intelligence"
+            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
+            and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
+            and (
+                file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
+                or file_content_type
+                in [
+                    "application/vnd.ms-excel",
+                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+                    "application/vnd.ms-powerpoint",
+                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+                ]
+            )
+        ):
+            loader = AzureAIDocumentIntelligenceLoader(
+                file_path=file_path,
+                api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
+                api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
+            )
        else:
             if file_ext == "pdf":
                 loader = PyPDFLoader(
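A simplified sketch of the loader-selection guard added above: Document Intelligence is chosen only when both credentials are present and the file type is supported. `pick_engine` is an illustrative stand-in, not part of the diff, and the real check also accepts the listed MIME types:

# Hedged sketch; endpoint/key values would be configuration, not literals.
SUPPORTED_EXTS = {"pdf", "xls", "xlsx", "docx", "ppt", "pptx"}

def pick_engine(engine: str, kwargs: dict, file_ext: str) -> str:
    if (
        engine == "document_intelligence"
        and kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT")
        and kwargs.get("DOCUMENT_INTELLIGENCE_KEY")
        and file_ext in SUPPORTED_EXTS
    ):
        return "document_intelligence"
    return "default"

assert pick_engine("document_intelligence", {}, "pdf") == "default"  # no creds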
backend/open_webui/retrieval/models/colbert.py

@@ -1,13 +1,19 @@
 import os
+import logging
 import torch
 import numpy as np
 from colbert.infra import ColBERTConfig
 from colbert.modeling.checkpoint import Checkpoint

+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+

 class ColBERT:
     def __init__(self, name, **kwargs) -> None:
-        print("ColBERT: Loading model", name)
+        log.info("ColBERT: Loading model", name)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

         DOCKER = kwargs.get("env") == "docker"
backend/open_webui/retrieval/utils.py

@@ -5,6 +5,7 @@ from typing import Optional, Union
 import asyncio
 import requests
+import hashlib

 from huggingface_hub import snapshot_download
 from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever

@@ -14,9 +15,16 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.misc import get_last_user_message
+from open_webui.utils.misc import get_last_user_message, calculate_sha256_string

-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.models.users import UserModel
+from open_webui.models.files import Files
+
+from open_webui.env import (
+    SRC_LOG_LEVELS,
+    OFFLINE_MODE,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

@@ -61,9 +69,7 @@ class VectorSearchRetriever(BaseRetriever):
 def query_doc(
-    collection_name: str,
-    query_embedding: list[float],
-    k: int,
+    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
 ):
     try:
         result = VECTOR_DB_CLIENT.search(

@@ -77,7 +83,20 @@ def query_doc(
         return result
     except Exception as e:
-        print(e)
+        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
+        raise e
+
+
+def get_doc(collection_name: str, user: UserModel = None):
+    try:
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+
+        if result:
+            log.info(f"query_doc:result {result.ids} {result.metadatas}")
+
+        return result
+    except Exception as e:
+        log.exception(f"Error getting doc {collection_name}: {e}")
         raise e

@@ -134,47 +153,80 @@ def query_doc_with_hybrid_search(
         raise e

-def merge_and_sort_query_results(
-    query_results: list[dict], k: int, reverse: bool = False
-) -> list[dict]:
+def merge_get_results(get_results: list[dict]) -> dict:
     # Initialize lists to store combined data
-    combined_distances = []
     combined_documents = []
     combined_metadatas = []
+    combined_ids = []

-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
+    for data in get_results:
         combined_documents.extend(data["documents"][0])
         combined_metadatas.extend(data["metadatas"][0])
+        combined_ids.extend(data["ids"][0])

-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
+    # Create the output dictionary
+    result = {
+        "documents": [combined_documents],
+        "metadatas": [combined_metadatas],
+        "ids": [combined_ids],
+    }
+
+    return result
+
+
+def merge_and_sort_query_results(
+    query_results: list[dict], k: int, reverse: bool = False
+) -> dict:
+    # Initialize lists to store combined data
+    combined = []
+    seen_hashes = set()  # To store unique document hashes
+
+    for data in query_results:
+        distances = data["distances"][0]
+        documents = data["documents"][0]
+        metadatas = data["metadatas"][0]
+
+        for distance, document, metadata in zip(distances, documents, metadatas):
+            if isinstance(document, str):
+                doc_hash = hashlib.md5(
+                    document.encode()
+                ).hexdigest()  # Compute a hash for uniqueness
+
+                if doc_hash not in seen_hashes:
+                    seen_hashes.add(doc_hash)
+                    combined.append((distance, document, metadata))

     # Sort the list based on distances
     combined.sort(key=lambda x: x[0], reverse=reverse)

-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
-    else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+    # Slice to keep only the top k elements
+    sorted_distances, sorted_documents, sorted_metadatas = (
+        zip(*combined[:k]) if combined else ([], [], [])
+    )

-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-
-    # Create the output dictionary
-    result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
-    }
-
-    return result
+    # Create and return the output dictionary
+    return {
+        "distances": [list(sorted_distances)],
+        "documents": [list(sorted_documents)],
+        "metadatas": [list(sorted_metadatas)],
+    }
+
+
+def get_all_items_from_collections(collection_names: list[str]) -> dict:
+    results = []
+
+    for collection_name in collection_names:
+        if collection_name:
+            try:
+                result = get_doc(collection_name=collection_name)
+                if result is not None:
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
+        else:
+            pass
+
+    return merge_get_results(results)

 def query_collection(

@@ -259,29 +311,35 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, user=None: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
             url=url,
             key=key,
+            user=user,
         )

-        def generate_multiple(query, func):
+        def generate_multiple(query, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(
+                        func(query[i : i + embedding_batch_size], user=user)
+                    )
                 return embeddings
             else:
-                return func(query)
+                return func(query, user)

-        return lambda query: generate_multiple(query, func)
+        return lambda query, user=None: generate_multiple(query, user, func)
+    else:
+        raise ValueError(f"Unknown embedding engine: {embedding_engine}")

 def get_sources_from_files(
+    request,
     files,
     queries,
     embedding_function,

@@ -289,21 +347,81 @@ def get_sources_from_files(
     reranking_function,
     r,
     hybrid_search,
+    full_context=False,
 ):
-    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
+    log.debug(
+        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+    )

     extracted_collections = []
     relevant_contexts = []

     for file in files:
-        if file.get("context") == "full":
+        context = None
+        if file.get("docs"):
+            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+            context = {
+                "documents": [[doc.get("content") for doc in file.get("docs")]],
+                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
+            }
+        elif file.get("context") == "full":
+            # Manual Full Mode Toggle
             context = {
                 "documents": [[file.get("file").get("data", {}).get("content")]],
                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
             }
-        else:
-            context = None
+        elif (
+            file.get("type") != "web_search"
+            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+        ):
+            # BYPASS_EMBEDDING_AND_RETRIEVAL
+            if file.get("type") == "collection":
+                file_ids = file.get("data", {}).get("file_ids", [])
+
+                documents = []
+                metadatas = []
+                for file_id in file_ids:
+                    file_object = Files.get_file_by_id(file_id)
+
+                    if file_object:
+                        documents.append(file_object.data.get("content", ""))
+                        metadatas.append(
+                            {
+                                "file_id": file_id,
+                                "name": file_object.filename,
+                                "source": file_object.filename,
+                            }
+                        )
+
+                context = {
+                    "documents": [documents],
+                    "metadatas": [metadatas],
+                }
+
+            elif file.get("id"):
+                file_object = Files.get_file_by_id(file.get("id"))
+                if file_object:
+                    context = {
+                        "documents": [[file_object.data.get("content", "")]],
+                        "metadatas": [
+                            [
+                                {
+                                    "file_id": file.get("id"),
+                                    "name": file_object.filename,
+                                    "source": file_object.filename,
+                                }
+                            ]
+                        ],
+                    }
+            elif file.get("file").get("data"):
+                context = {
+                    "documents": [[file.get("file").get("data", {}).get("content")]],
+                    "metadatas": [
+                        [file.get("file").get("data", {}).get("metadata", {})]
+                    ],
+                }
+        else:
             collection_names = []
             if file.get("type") == "collection":
                 if file.get("legacy"):

@@ -323,6 +441,13 @@ def get_sources_from_files(
                 log.debug(f"skipping {file} as it has already been extracted")
                 continue

+            if full_context:
+                try:
+                    context = get_all_items_from_collections(collection_names)
+                except Exception as e:
+                    log.exception(e)
+
+            else:
                 try:
                     context = None
                     if file.get("type") == "text":

@@ -359,6 +484,7 @@ def get_sources_from_files(
         if context:
             if "data" in file:
                 del file["data"]
+
             relevant_contexts.append({**context, "file": file})

     sources = []

@@ -423,7 +549,11 @@ def get_model_path(model: str, update_model: bool = False):
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(

@@ -431,6 +561,16 @@ def generate_openai_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )

@@ -441,12 +581,12 @@ def generate_openai_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating openai batch embeddings: {e}")
         return None

 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(

@@ -454,6 +594,16 @@ def generate_ollama_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )

@@ -465,29 +615,36 @@ def generate_ollama_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating ollama batch embeddings: {e}")
         return None

 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
+    user = kwargs.get("user")

     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{
+                    "model": model,
+                    "texts": [text],
+                    "url": url,
+                    "key": key,
+                    "user": user,
+                }
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)

         return embeddings[0] if isinstance(text, str) else embeddings
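The rewritten `merge_and_sort_query_results` now de-duplicates identical document texts across collections via an MD5 hash before ranking. A standalone sketch of that logic (`dedupe_and_rank` is illustrative; the real function works on the nested result dicts shown above):

# Hedged sketch of the hash-based de-duplication plus top-k cut.
import hashlib

def dedupe_and_rank(results, k, reverse=False):
    combined, seen = [], set()
    for distance, document, metadata in results:
        if isinstance(document, str):
            h = hashlib.md5(document.encode()).hexdigest()
            if h not in seen:
                seen.add(h)
                combined.append((distance, document, metadata))
    combined.sort(key=lambda x: x[0], reverse=reverse)
    return combined[:k]

hits = [(0.2, "same text", {}), (0.9, "same text", {}), (0.5, "other", {})]
assert [d for _, d, _ in dedupe_and_rank(hits, k=2)] == ["same text", "other"]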
backend/open_webui/retrieval/vector/connector.py

@@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
     from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

     VECTOR_DB_CLIENT = PgvectorClient()
+elif VECTOR_DB == "elasticsearch":
+    from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
+
+    VECTOR_DB_CLIENT = ElasticsearchClient()
 else:
     from open_webui.retrieval.vector.dbs.chroma import ChromaClient
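The connector keeps its import-at-selection pattern, with `elasticsearch` slotting in as a new `VECTOR_DB` value. Roughly, with strings standing in for the client classes:

# Hedged sketch of the backend selection; real code imports the client class.
import os

VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")

if VECTOR_DB == "elasticsearch":
    backend = "ElasticsearchClient"  # imported lazily in the real module
elif VECTOR_DB == "pgvector":
    backend = "PgvectorClient"
else:
    backend = "ChromaClient"

print(f"VECTOR_DB={VECTOR_DB!r} -> {backend}")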
8  backend/open_webui/retrieval/vector/dbs/chroma.py  Normal file → Executable file

@@ -1,4 +1,5 @@
 import chromadb
+import logging
 from chromadb import Settings
 from chromadb.utils.batch_utils import create_batches

@@ -16,6 +17,10 @@ from open_webui.config import (
     CHROMA_CLIENT_AUTH_PROVIDER,
     CHROMA_CLIENT_AUTH_CREDENTIALS,
 )
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])

 class ChromaClient:

@@ -102,8 +107,7 @@ class ChromaClient:
                 }
             )
             return None
-        except Exception as e:
-            print(e)
+        except:
             return None

     def get(self, collection_name: str) -> Optional[GetResult]:
274  backend/open_webui/retrieval/vector/dbs/elasticsearch.py  Normal file

@@ -0,0 +1,274 @@
+from elasticsearch import Elasticsearch, BadRequestError
+from typing import Optional
+import ssl
+from elasticsearch.helpers import bulk, scan
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+    ELASTICSEARCH_URL,
+    ELASTICSEARCH_CA_CERTS,
+    ELASTICSEARCH_API_KEY,
+    ELASTICSEARCH_USERNAME,
+    ELASTICSEARCH_PASSWORD,
+    ELASTICSEARCH_CLOUD_ID,
+    SSL_ASSERT_FINGERPRINT,
+)
+
+
+class ElasticsearchClient:
+    """
+    Important:
+    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
+    an index for each file but store it as a text field, while seperating to different index
+    baesd on the embedding length.
+    """
+
+    def __init__(self):
+        self.index_prefix = "open_webui_collections"
+        self.client = Elasticsearch(
+            hosts=[ELASTICSEARCH_URL],
+            ca_certs=ELASTICSEARCH_CA_CERTS,
+            api_key=ELASTICSEARCH_API_KEY,
+            cloud_id=ELASTICSEARCH_CLOUD_ID,
+            basic_auth=(
+                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
+                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
+                else None
+            ),
+            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
+        )
+
+    # Status: works
+    def _get_index_name(self, dimension: int) -> str:
+        return f"{self.index_prefix}_d{str(dimension)}"
+
+    # Status: works
+    def _scan_result_to_get_result(self, result) -> GetResult:
+        if not result:
+            return None
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    # Status: works
+    def _result_to_get_result(self, result) -> GetResult:
+        if not result["hits"]["hits"]:
+            return None
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    # Status: works
+    def _result_to_search_result(self, result) -> SearchResult:
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            distances.append(hit["_score"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return SearchResult(
+            ids=[ids],
+            distances=[distances],
+            documents=[documents],
+            metadatas=[metadatas],
+        )
+
+    # Status: works
+    def _create_index(self, dimension: int):
+        body = {
+            "mappings": {
+                "properties": {
+                    "collection": {"type": "keyword"},
+                    "id": {"type": "keyword"},
+                    "vector": {
+                        "type": "dense_vector",
+                        "dims": dimension,  # Adjust based on your vector dimensions
+                        "index": True,
+                        "similarity": "cosine",
+                    },
+                    "text": {"type": "text"},
+                    "metadata": {"type": "object"},
+                }
+            }
+        }
+        self.client.indices.create(index=self._get_index_name(dimension), body=body)
+
+    # Status: works
+
+    def _create_batches(self, items: list[VectorItem], batch_size=100):
+        for i in range(0, len(items), batch_size):
+            yield items[i : min(i + batch_size, len(items))]
+
+    # Status: works
+    def has_collection(self, collection_name) -> bool:
+        query_body = {"query": {"bool": {"filter": []}}}
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
+
+        try:
+            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
+
+            return result.body["count"] > 0
+        except Exception as e:
+            return None
+
+    # @TODO: Make this delete a collection and not an index
+    def delete_colleciton(self, collection_name: str):
+        # TODO: fix this to include the dimension or a * prefix
+        # delete_collection here means delete a bunch of documents for an index.
+        # We are simply adapting to the norms of the other DBs.
+        self.client.indices.delete(index=self._get_collection_name(collection_name))
+
+    # Status: works
+    def search(
+        self, collection_name: str, vectors: list[list[float]], limit: int
+    ) -> Optional[SearchResult]:
+        query = {
+            "size": limit,
+            "_source": ["text", "metadata"],
+            "query": {
+                "script_score": {
+                    "query": {
+                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
+                    },
+                    "script": {
+                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
+                        "params": {
+                            "vector": vectors[0]
+                        },  # Assuming single query vector
+                    },
+                }
+            },
+        }
+
+        result = self.client.search(
+            index=self._get_index_name(len(vectors[0])), body=query
+        )
+
+        return self._result_to_search_result(result)
+
+    # Status: only tested halfwat
+    def query(
+        self, collection_name: str, filter: dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        if not self.has_collection(collection_name):
+            return None
+
+        query_body = {
+            "query": {"bool": {"filter": []}},
+            "_source": ["text", "metadata"],
+        }
+
+        for field, value in filter.items():
+            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
+        size = limit if limit else 10
+
+        try:
+            result = self.client.search(
+                index=f"{self.index_prefix}*",
+                body=query_body,
+                size=size,
+            )
+
+            return self._result_to_get_result(result)
+
+        except Exception as e:
+            return None
+
+    # Status: works
+    def _has_index(self, dimension: int):
+        return self.client.indices.exists(
+            index=self._get_index_name(dimension=dimension)
+        )
+
+    def get_or_create_index(self, dimension: int):
+        if not self._has_index(dimension=dimension):
+            self._create_index(dimension=dimension)
+
+    # Status: works
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        # Get all the items in the collection.
+        query = {
+            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
+            "_source": ["text", "metadata"],
+        }
+        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
+
+        return self._scan_result_to_get_result(results)
+
+    # Status: works
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        if not self._has_index(dimension=len(items[0]["vector"])):
+            self._create_index(dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                    "_id": item["id"],
+                    "_source": {
+                        "collection": collection_name,
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                }
+                for item in batch
+            ]
+            bulk(self.client, actions)
+
+    # Status: should work
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        if not self._has_index(dimension=len(items[0]["vector"])):
+            self._create_index(collection_name, dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                    "_id": item["id"],
+                    "_source": {
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                }
+                for item in batch
+            ]
+            self.client.bulk(actions)
+
+    # TODO: This currently deletes by * which is not always supported in ElasticSearch.
+    # Need to read a bit before changing. Also, need to delete from a specific collection
+    def delete(self, collection_name: str, ids: list[str]):
+        # Assuming ID is unique across collections and indexes
+        actions = [
+            {"delete": {"_index": f"{self.index_prefix}*", "_id": id}} for id in ids
+        ]
+        self.client.bulk(body=actions)
+
+    def reset(self):
+        indices = self.client.indices.get(index=f"{self.index_prefix}*")
+        for index in indices:
+            self.client.indices.delete(index=index)
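On the new client's `search()`: the script `cosineSimilarity(params.vector, 'vector') + 1.0` presumably shifts scores into [0, 2], since Elasticsearch's `script_score` rejects negative scores. A pure-Python equivalent of that scoring, for reference:

# Hedged sketch of the shifted cosine score computed server-side by ES.
import math

def shifted_cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm + 1.0

assert shifted_cosine([1.0, 0.0], [1.0, 0.0]) == 2.0   # identical vectors
assert shifted_cosine([1.0, 0.0], [-1.0, 0.0]) == 0.0  # opposite vectors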
@ -1,20 +1,28 @@
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
 import json
+import logging
 from typing import Optional

 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_DB,
+    MILVUS_TOKEN,
 )
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])


 class MilvusClient:
     def __init__(self):
         self.collection_prefix = "open_webui"
-        self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
+        if MILVUS_TOKEN is None:
+            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
+        else:
+            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

     def _result_to_get_result(self, result) -> GetResult:
         ids = []
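The switch from database= to db_name= matches the keyword that pymilvus's MilvusClient accepts, and the token branch only adds credentials when auth is configured. A hedged connection sketch (URI, database name, and token values are placeholders):

from pymilvus import MilvusClient as Client

MILVUS_URI = "http://localhost:19530"  # placeholder
MILVUS_DB = "default"                  # placeholder
MILVUS_TOKEN = None                    # e.g. "user:password" when auth is enabled

if MILVUS_TOKEN is None:
    client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else:
    client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)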
@ -164,7 +172,7 @@ class MilvusClient:
         try:
             # Loop until there are no more items to fetch or the desired limit is reached
             while remaining > 0:
-                print("remaining", remaining)
+                log.info(f"remaining: {remaining}")
                 current_fetch = min(
                     max_limit, remaining
                 )  # Determine how many items to fetch in this iteration

@ -191,10 +199,12 @@ class MilvusClient:
                 if results_count < current_fetch:
                     break

-            print(all_results)
+            log.debug(all_results)
             return self._result_to_get_result([all_results])
         except Exception as e:
-            print(e)
+            log.exception(
+                f"Error querying collection {collection_name} with limit {limit}: {e}"
+            )
             return None

     def get(self, collection_name: str) -> Optional[GetResult]:
@ -49,7 +49,7 @@ class OpenSearchClient:
             ids=ids, distances=distances, documents=documents, metadatas=metadatas
         )

-    def _create_index(self, index_name: str, dimension: int):
+    def _create_index(self, collection_name: str, dimension: int):
         body = {
             "mappings": {
                 "properties": {
@ -72,24 +72,28 @@ class OpenSearchClient:
                 }
             }
         }
-        self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
+        self.client.indices.create(
+            index=f"{self.index_prefix}_{collection_name}", body=body
+        )

     def _create_batches(self, items: list[VectorItem], batch_size=100):
         for i in range(0, len(items), batch_size):
             yield items[i : i + batch_size]

-    def has_collection(self, index_name: str) -> bool:
+    def has_collection(self, collection_name: str) -> bool:
         # has_collection here means has index.
         # We are simply adapting to the norms of the other DBs.
-        return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
+        return self.client.indices.exists(
+            index=f"{self.index_prefix}_{collection_name}"
+        )

-    def delete_colleciton(self, index_name: str):
+    def delete_colleciton(self, collection_name: str):
         # delete_collection here means delete index.
         # We are simply adapting to the norms of the other DBs.
-        self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
+        self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")

     def search(
-        self, index_name: str, vectors: list[list[float]], limit: int
+        self, collection_name: str, vectors: list[list[float]], limit: int
     ) -> Optional[SearchResult]:
         query = {
             "size": limit,
@ -108,26 +112,55 @@ class OpenSearchClient:
         }

         result = self.client.search(
-            index=f"{self.index_prefix}_{index_name}", body=query
+            index=f"{self.index_prefix}_{collection_name}", body=query
         )

         return self._result_to_search_result(result)

-    def get_or_create_index(self, index_name: str, dimension: int):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension)
+    def query(
+        self, collection_name: str, filter: dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        if not self.has_collection(collection_name):
+            return None
+
+        query_body = {
+            "query": {"bool": {"filter": []}},
+            "_source": ["text", "metadata"],
+        }
+
+        for field, value in filter.items():
+            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+
+        size = limit if limit else 10
+
+        try:
+            result = self.client.search(
+                index=f"{self.index_prefix}_{collection_name}",
+                body=query_body,
+                size=size,
+            )
+
+            return self._result_to_get_result(result)
+
+        except Exception as e:
+            return None
+
+    def _create_index_if_not_exists(self, collection_name: str, dimension: int):
+        if not self.has_index(collection_name):
+            self._create_index(collection_name, dimension)

-    def get(self, index_name: str) -> Optional[GetResult]:
+    def get(self, collection_name: str) -> Optional[GetResult]:
         query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}

         result = self.client.search(
-            index=f"{self.index_prefix}_{index_name}", body=query
+            index=f"{self.index_prefix}_{collection_name}", body=query
         )
         return self._result_to_get_result(result)

-    def insert(self, index_name: str, items: list[VectorItem]):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension=len(items[0]["vector"]))
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        self._create_index_if_not_exists(
+            collection_name=collection_name, dimension=len(items[0]["vector"])
+        )

         for batch in self._create_batches(items):
             actions = [
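The new query method builds an OpenSearch bool filter out of exact-match term clauses, one per field in the caller's filter dict. A small self-contained sketch of the payload it constructs, using an assumed example filter:

filter_fields = {"collection": "docs"}  # assumed example filter

query_body = {
    "query": {"bool": {"filter": []}},
    "_source": ["text", "metadata"],
}
for field, value in filter_fields.items():
    query_body["query"]["bool"]["filter"].append({"term": {field: value}})

# query_body is now:
# {
#     "query": {"bool": {"filter": [{"term": {"collection": "docs"}}]}},
#     "_source": ["text", "metadata"],
# }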
@ -145,15 +178,17 @@ class OpenSearchClient:
             ]
             self.client.bulk(actions)

-    def upsert(self, index_name: str, items: list[VectorItem]):
-        if not self.has_index(index_name):
-            self._create_index(index_name, dimension=len(items[0]["vector"]))
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        self._create_index_if_not_exists(
+            collection_name=collection_name, dimension=len(items[0]["vector"])
+        )

         for batch in self._create_batches(items):
             actions = [
                 {
                     "index": {
                         "_id": item["id"],
+                        "_index": f"{self.index_prefix}_{collection_name}",
                         "_source": {
                             "vector": item["vector"],
                             "text": item["text"],
@ -165,9 +200,9 @@ class OpenSearchClient:
             ]
             self.client.bulk(actions)

-    def delete(self, index_name: str, ids: list[str]):
+    def delete(self, collection_name: str, ids: list[str]):
         actions = [
-            {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
+            {"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
             for id in ids
         ]
         self.client.bulk(body=actions)
@ -1,4 +1,5 @@
 from typing import Optional, List, Dict, Any
+import logging
 from sqlalchemy import (
     cast,
     column,

@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+
+from open_webui.env import SRC_LOG_LEVELS

 VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 Base = declarative_base()

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+

 class DocumentChunk(Base):
     __tablename__ = "document_chunk"

@ -82,10 +88,10 @@ class PgvectorClient:
                 )
             )
             self.session.commit()
-            print("Initialization complete.")
+            log.info("Initialization complete.")
         except Exception as e:
             self.session.rollback()
-            print(f"Error during initialization: {e}")
+            log.exception(f"Error during initialization: {e}")
             raise

     def check_vector_length(self) -> None:

@ -150,12 +156,12 @@ class PgvectorClient:
                 new_items.append(new_chunk)
             self.session.bulk_save_objects(new_items)
             self.session.commit()
-            print(
+            log.info(
                 f"Inserted {len(new_items)} items into collection '{collection_name}'."
             )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during insert: {e}")
+            log.exception(f"Error during insert: {e}")
             raise

     def upsert(self, collection_name: str, items: List[VectorItem]) -> None:

@ -184,10 +190,12 @@ class PgvectorClient:
                 )
                 self.session.add(new_chunk)
             self.session.commit()
-            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
+            log.info(
+                f"Upserted {len(items)} items into collection '{collection_name}'."
+            )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during upsert: {e}")
+            log.exception(f"Error during upsert: {e}")
             raise

     def search(

@ -278,7 +286,7 @@ class PgvectorClient:
                 ids=ids, distances=distances, documents=documents, metadatas=metadatas
             )
         except Exception as e:
-            print(f"Error during search: {e}")
+            log.exception(f"Error during search: {e}")
             return None

     def query(

@ -310,7 +318,7 @@ class PgvectorClient:
                 metadatas=metadatas,
             )
         except Exception as e:
-            print(f"Error during query: {e}")
+            log.exception(f"Error during query: {e}")
             return None

     def get(

@ -334,7 +342,7 @@ class PgvectorClient:

             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
-            print(f"Error during get: {e}")
+            log.exception(f"Error during get: {e}")
             return None

     def delete(

@ -356,22 +364,22 @@ class PgvectorClient:
             )
             deleted = query.delete(synchronize_session=False)
             self.session.commit()
-            print(f"Deleted {deleted} items from collection '{collection_name}'.")
+            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
         except Exception as e:
             self.session.rollback()
-            print(f"Error during delete: {e}")
+            log.exception(f"Error during delete: {e}")
             raise

     def reset(self) -> None:
         try:
             deleted = self.session.query(DocumentChunk).delete()
             self.session.commit()
-            print(
+            log.info(
                 f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
             )
         except Exception as e:
             self.session.rollback()
-            print(f"Error during reset: {e}")
+            log.exception(f"Error during reset: {e}")
             raise

     def close(self) -> None:

@ -387,9 +395,9 @@ class PgvectorClient:
             )
             return exists
         except Exception as e:
-            print(f"Error checking collection existence: {e}")
+            log.exception(f"Error checking collection existence: {e}")
             return False

     def delete_collection(self, collection_name: str) -> None:
         self.delete(collection_name)
-        print(f"Collection '{collection_name}' deleted.")
+        log.info(f"Collection '{collection_name}' deleted.")
@ -1,4 +1,5 @@
 from typing import Optional
+import logging

 from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct

@ -6,9 +7,13 @@ from qdrant_client.models import models

 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import QDRANT_URI, QDRANT_API_KEY
+from open_webui.env import SRC_LOG_LEVELS

 NO_LIMIT = 999999999

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+

 class QdrantClient:
     def __init__(self):

@ -49,7 +54,7 @@ class QdrantClient:
             ),
         )

-        print(f"collection {collection_name_with_prefix} successfully created!")
+        log.info(f"collection {collection_name_with_prefix} successfully created!")

     def _create_collection_if_not_exists(self, collection_name, dimension):
         if not self.has_collection(collection_name=collection_name):

@ -120,7 +125,7 @@ class QdrantClient:
             )
             return self._result_to_get_result(points.points)
         except Exception as e:
-            print(e)
+            log.exception(f"Error querying a collection '{collection_name}': {e}")
             return None

     def get(self, collection_name: str) -> Optional[GetResult]:
backend/open_webui/retrieval/web/bocha.py (new file, 65 lines)
@ -0,0 +1,65 @@
import logging
from typing import Optional

import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def _parse_response(response):
    result = {}
    if "data" in response:
        data = response["data"]
        if "webPages" in data:
            webPages = data["webPages"]
            if "value" in webPages:
                result["webpage"] = [
                    {
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "url": item.get("url", ""),
                        "snippet": item.get("snippet", ""),
                        "summary": item.get("summary", ""),
                        "siteName": item.get("siteName", ""),
                        "siteIcon": item.get("siteIcon", ""),
                        "datePublished": item.get("datePublished", "")
                        or item.get("dateLastCrawled", ""),
                    }
                    for item in webPages["value"]
                ]
    return result


def search_bocha(
    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """Search using Bocha's Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Bocha Search API key
        query (str): The query to search for
    """
    url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = json.dumps(
        {"query": query, "summary": True, "freshness": "noLimit", "count": count}
    )

    response = requests.post(url, headers=headers, data=payload, timeout=5)
    response.raise_for_status()
    results = _parse_response(response.json())
    print(results)
    if filter_list:
        results = get_filtered_results(results, filter_list)

    return [
        SearchResult(
            link=result["url"], title=result.get("name"), snippet=result.get("summary")
        )
        for result in results.get("webpage", [])[:count]
    ]
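To see what _parse_response keeps, here is a worked example with an assumed, abridged Bocha payload; note how datePublished falls back to dateLastCrawled when absent:

response = {  # assumed, abridged Bocha reply
    "data": {
        "webPages": {
            "value": [
                {
                    "name": "Example",
                    "url": "https://example.com",
                    "snippet": "short text",
                    "summary": "longer text",
                    "dateLastCrawled": "2024-01-01",
                }
            ]
        }
    }
}

pages = response["data"]["webPages"]["value"]
parsed = [
    {
        "name": p.get("name", ""),
        "url": p.get("url", ""),
        # datePublished falls back to dateLastCrawled when absent:
        "datePublished": p.get("datePublished", "") or p.get("dateLastCrawled", ""),
    }
    for p in pages
]
assert parsed[0]["datePublished"] == "2024-01-01"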
@ -32,19 +32,15 @@ def search_duckduckgo(
     # Convert the search results into a list
     search_results = [r for r in ddgs_gen]

-    # Create an empty list to store the SearchResult objects
-    results = []
-    # Iterate over each search result
-    for result in search_results:
-        # Create a SearchResult object and append it to the results list
-        results.append(
-            SearchResult(
-                link=result["href"],
-                title=result.get("title"),
-                snippet=result.get("body"),
-            )
-        )
     if filter_list:
-        results = get_filtered_results(results, filter_list)
+        search_results = get_filtered_results(search_results, filter_list)

     # Return the list of search results
-    return results
+    return [
+        SearchResult(
+            link=result["href"],
+            title=result.get("title"),
+            snippet=result.get("body"),
+        )
+        for result in search_results
+    ]
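The rewrite filters the raw rows first and only then builds SearchResult objects in a single comprehension, so nothing is constructed just to be discarded. The same shape in isolation, with hypothetical rows and a stand-in for the domain filter:

search_results = [
    {"href": "https://a.example", "title": "A", "body": "first hit"},
    {"href": "https://b.example", "title": "B", "body": "second hit"},
]

allowed = {"a.example"}  # stand-in for get_filtered_results
search_results = [r for r in search_results if r["href"].split("//")[1] in allowed]

results = [
    {"link": r["href"], "title": r.get("title"), "snippet": r.get("body")}
    for r in search_results
]
# only the a.example row survives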
backend/open_webui/retrieval/web/exa.py (new file, 76 lines)
@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional

import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

EXA_API_BASE = "https://api.exa.ai"


@dataclass
class ExaResult:
    url: str
    title: str
    text: str


def search_exa(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using Exa Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): An Exa Search API key
        query (str): The query to search for
        count (int): Number of results to return
        filter_list (Optional[list[str]]): List of domains to filter results by
    """
    log.info(f"Searching with Exa for query: {query}")

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "query": query,
        "numResults": count or 5,
        "includeDomains": filter_list,
        "contents": {"text": True, "highlights": True},
        "type": "auto",  # Use the auto search type (keyword or neural)
    }

    try:
        response = requests.post(
            f"{EXA_API_BASE}/search", headers=headers, json=payload
        )
        response.raise_for_status()
        data = response.json()

        results = []
        for result in data["results"]:
            results.append(
                ExaResult(
                    url=result["url"],
                    title=result["title"],
                    text=result["text"],
                )
            )

        log.info(f"Found {len(results)} results")
        return [
            SearchResult(
                link=result.url,
                title=result.title,
                snippet=result.text,
            )
            for result in results
        ]
    except Exception as e:
        log.error(f"Error searching Exa: {e}")
        return []
@ -17,34 +17,53 @@ def search_google_pse(
     filter_list: Optional[list[str]] = None,
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
+    Handles pagination for counts greater than 10.

     Args:
         api_key (str): A Programmable Search Engine API key
         search_engine_id (str): A Programmable Search Engine ID
         query (str): The query to search for
+        count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
+        filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
+
+    Returns:
+        list[SearchResult]: A list of SearchResult objects.
     """
     url = "https://www.googleapis.com/customsearch/v1"

     headers = {"Content-Type": "application/json"}
+    all_results = []
+    start_index = 1  # Google PSE start parameter is 1-based
+
+    while count > 0:
+        num_results_this_page = min(count, 10)  # Google PSE max results per page is 10
         params = {
             "cx": search_engine_id,
             "q": query,
             "key": api_key,
-            "num": count,
+            "num": num_results_this_page,
+            "start": start_index,
         }

         response = requests.request("GET", url, headers=headers, params=params)
         response.raise_for_status()

         json_response = response.json()
         results = json_response.get("items", [])
+        if results:  # check if results are returned. If not, no more pages to fetch.
+            all_results.extend(results)
+            count -= len(
+                results
+            )  # Decrement count by the number of results fetched in this page.
+            start_index += 10  # Increment start index for the next page
+        else:
+            break  # No more results from Google PSE, break the loop

     if filter_list:
-        results = get_filtered_results(results, filter_list)
+        all_results = get_filtered_results(all_results, filter_list)

     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("snippet"),
         )
-        for result in results
+        for result in all_results
     ]
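The paging arithmetic: with count=25 the loop requests num=10/start=1, num=10/start=11, then num=5/start=21, and stops early if a page comes back empty. A standalone sketch of just that loop logic, assuming every page comes back full (no network involved):

count = 25
start_index = 1
requests_made = []

while count > 0:
    num_results_this_page = min(count, 10)
    requests_made.append((num_results_this_page, start_index))
    fetched = num_results_this_page  # assume a full page came back
    count -= fetched
    start_index += 10

assert requests_made == [(10, 1), (10, 11), (5, 21)]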
@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
         list[SearchResult]: A list of search results
     """
     jina_search_endpoint = "https://s.jina.ai/"
-    headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
-    url = str(URL(jina_search_endpoint + query))
-    response = requests.get(url, headers=headers)
+
+    headers = {
+        "Accept": "application/json",
+        "Content-Type": "application/json",
+        "Authorization": api_key,
+        "X-Retain-Images": "none",
+    }
+
+    payload = {"q": query, "count": count if count <= 10 else 10}
+
+    url = str(URL(jina_search_endpoint))
+    response = requests.post(url, headers=headers, json=payload)
     response.raise_for_status()
     data = response.json()

     results = []
-    for result in data["data"][:count]:
+    for result in data["data"]:
         results.append(
             SearchResult(
                 link=result["url"],
@ -1,3 +1,5 @@
+import validators
+
 from typing import Optional
 from urllib.parse import urlparse

@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
     filtered_results = []
     for result in results:
         url = result.get("url") or result.get("link", "")
+        if not validators.url(url):
+            continue
         domain = urlparse(url).netloc
         if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
             filtered_results.append(result)
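The added validators.url guard means malformed URLs are skipped instead of producing empty netlocs that could never match a domain. A small check of the flow, with made-up rows:

import validators
from urllib.parse import urlparse

results = [
    {"link": "https://docs.example.org/page"},
    {"link": "not-a-url"},
    {"link": "https://other.net/x"},
]
filter_list = ["example.org"]

filtered = []
for result in results:
    url = result.get("url") or result.get("link", "")
    if not validators.url(url):
        continue  # "not-a-url" is dropped here
    domain = urlparse(url).netloc
    if any(domain.endswith(d) for d in filter_list):
        filtered.append(result)

# filtered keeps only the docs.example.org row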
backend/open_webui/retrieval/web/perplexity.py (new file, 87 lines)
@ -0,0 +1,87 @@
import logging
from typing import Optional, List
import requests

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_perplexity(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using Perplexity API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Perplexity API key
        query (str): The query to search for
        count (int): Maximum number of results to return
    """

    # Handle PersistentConfig object
    if hasattr(api_key, "__str__"):
        api_key = str(api_key)

    try:
        url = "https://api.perplexity.ai/chat/completions"

        # Create payload for the API call
        payload = {
            "model": "sonar",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a search assistant. Provide factual information with citations.",
                },
                {"role": "user", "content": query},
            ],
            "temperature": 0.2,  # Lower temperature for more factual responses
            "stream": False,
        }

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Make the API request
        response = requests.request("POST", url, json=payload, headers=headers)

        # Parse the JSON response
        json_response = response.json()

        # Extract citations from the response
        citations = json_response.get("citations", [])

        # Create search results from citations
        results = []
        for i, citation in enumerate(citations[:count]):
            # Extract content from the response to use as snippet
            content = ""
            if "choices" in json_response and json_response["choices"]:
                if i == 0:
                    content = json_response["choices"][0]["message"]["content"]

            result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
            results.append(result)

        if filter_list:
            results = get_filtered_results(results, filter_list)

        return [
            SearchResult(
                link=result["link"], title=result["title"], snippet=result["snippet"]
            )
            for result in results[:count]
        ]
    except Exception as e:
        log.error(f"Error searching with Perplexity API: {e}")
        return []
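Perplexity returns one completion plus a citations list, so only the first mapped result carries the answer text as its snippet. A self-contained sketch of that mapping, with an assumed response shape:

json_response = {  # assumed shape of a Perplexity chat/completions reply
    "citations": ["https://site-a.example", "https://site-b.example"],
    "choices": [{"message": {"content": "Answer text with citations."}}],
}
count = 2

results = []
for i, citation in enumerate(json_response.get("citations", [])[:count]):
    content = ""
    if json_response.get("choices") and i == 0:
        content = json_response["choices"][0]["message"]["content"]
    results.append({"link": citation, "title": f"Source {i+1}", "snippet": content})

# results[0]["snippet"] holds the answer; results[1]["snippet"] is ""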
backend/open_webui/retrieval/web/serpapi.py (new file, 48 lines)
@ -0,0 +1,48 @@
import logging
from typing import Optional
from urllib.parse import urlencode

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serpapi(
    api_key: str,
    engine: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using serpapi.com's API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A serpapi.com API key
        query (str): The query to search for
    """
    url = "https://serpapi.com/search"

    engine = engine or "google"

    payload = {"engine": engine, "q": query, "api_key": api_key}

    url = f"{url}?{urlencode(payload)}"
    response = requests.request("GET", url)

    json_response = response.json()
    log.info(f"results from serpapi search: {json_response}")

    results = sorted(
        json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
    )
    if filter_list:
        results = get_filtered_results(results, filter_list)
    return [
        SearchResult(
            link=result["link"], title=result["title"], snippet=result["snippet"]
        )
        for result in results[:count]
    ]
@ -1,4 +1,5 @@
 import logging
+from typing import Optional

 import requests
 from open_webui.retrieval.web.main import SearchResult

@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_tavily(
+    api_key: str,
+    query: str,
+    count: int,
+    filter_list: Optional[list[str]] = None,
+    # **kwargs,
+) -> list[SearchResult]:
     """Search using Tavily's Search API and return the results as a list of SearchResult objects.

     Args:

@ -20,7 +27,6 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
     """
     url = "https://api.tavily.com/search"
     data = {"query": query, "api_key": api_key}
-
     response = requests.post(url, json=data)
     response.raise_for_status()
@ -1,19 +1,38 @@
-import socket
-import urllib.parse
-import validators
-from typing import Union, Sequence, Iterator
-
-from langchain_community.document_loaders import (
-    WebBaseLoader,
-)
-from langchain_core.documents import Document
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
-from open_webui.env import SRC_LOG_LEVELS
-
+import asyncio
 import logging
+import socket
+import ssl
+import urllib.parse
+import urllib.request
+from collections import defaultdict
+from datetime import datetime, time, timedelta
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Union,
+    Literal,
+)
+import aiohttp
+import certifi
+import validators
+from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
+from langchain_community.document_loaders.firecrawl import FireCrawlLoader
+from langchain_community.document_loaders.base import BaseLoader
+from langchain_core.documents import Document
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.config import (
+    ENABLE_RAG_LOCAL_WEB_FETCH,
+    PLAYWRIGHT_WS_URI,
+    RAG_WEB_LOADER_ENGINE,
+    FIRECRAWL_API_BASE_URL,
+    FIRECRAWL_API_KEY,
+)
+from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -43,6 +62,17 @@ def validate_url(url: Union[str, Sequence[str]]):
     return False


+def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
+    valid_urls = []
+    for u in url:
+        try:
+            if validate_url(u):
+                valid_urls.append(u)
+        except ValueError:
+            continue
+    return valid_urls
+
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
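Unlike validate_url, which can raise on blocked addresses, safe_validate_urls swallows those ValueErrors and returns only the URLs that pass. A tiny runnable sketch with a stub standing in for validate_url (the stub's loopback rule is an assumption, mimicking the local-fetch guard):

def validate_url_stub(u: str) -> bool:
    # Stand-in for validate_url: reject loopback to mimic the local-fetch guard.
    if "127.0.0.1" in u:
        raise ValueError("local address blocked")
    return u.startswith("http")

def safe_validate(urls):
    valid = []
    for u in urls:
        try:
            if validate_url_stub(u):
                valid.append(u)
        except ValueError:
            continue
    return valid

print(safe_validate(["https://example.com", "http://127.0.0.1/x"]))
# ['https://example.com']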
@ -54,9 +84,381 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses


+def extract_metadata(soup, url):
+    metadata = {"source": url}
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get("content", "No description found.")
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", "No language found.")
+    return metadata
+
+
+def verify_ssl_cert(url: str) -> bool:
+    """Verify SSL certificate for the given URL."""
+    if not url.startswith("https://"):
+        return True
+
+    try:
+        hostname = url.split("://")[-1].split("/")[0]
+        context = ssl.create_default_context(cafile=certifi.where())
+        with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
+            s.connect((hostname, 443))
+        return True
+    except ssl.SSLError:
+        return False
+    except Exception as e:
+        log.warning(f"SSL verification failed for {url}: {str(e)}")
+        return False
+
+
+class SafeFireCrawlLoader(BaseLoader):
+    def __init__(
+        self,
+        web_paths,
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        api_key: Optional[str] = None,
+        api_url: Optional[str] = None,
+        mode: Literal["crawl", "scrape", "map"] = "crawl",
+        proxy: Optional[Dict[str, str]] = None,
+        params: Optional[Dict] = None,
+    ):
+        """Concurrent document loader for FireCrawl operations.
+
+        Executes multiple FireCrawlLoader instances concurrently using thread pooling
+        to improve bulk processing efficiency.
+        Args:
+            web_paths: List of URLs/paths to process.
+            verify_ssl: If True, verify SSL certificates.
+            trust_env: If True, use proxy settings from environment variables.
+            requests_per_second: Number of requests per second to limit to.
+            continue_on_failure (bool): If True, continue loading other URLs on failure.
+            api_key: API key for FireCrawl service. Defaults to None
+                (uses FIRE_CRAWL_API_KEY environment variable if not provided).
+            api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
+            mode: Operation mode selection:
+                - 'crawl': Website crawling mode (default)
+                - 'scrape': Direct page scraping
+                - 'map': Site map generation
+            proxy: Proxy override settings for the FireCrawl API.
+            params: The parameters to pass to the Firecrawl API.
+                Examples include crawlerOptions.
+                For more details, visit: https://github.com/mendableai/firecrawl-py
+        """
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+        self.web_paths = web_paths
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+        self.trust_env = trust_env
+        self.continue_on_failure = continue_on_failure
+        self.api_key = api_key
+        self.api_url = api_url
+        self.mode = mode
+        self.params = params
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Load documents concurrently using FireCrawl."""
+        for url in self.web_paths:
+            try:
+                self._safe_process_url_sync(url)
+                loader = FireCrawlLoader(
+                    url=url,
+                    api_key=self.api_key,
+                    api_url=self.api_url,
+                    mode=self.mode,
+                    params=self.params,
+                )
+                yield from loader.lazy_load()
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.exception(e, "Error loading %s", url)
+                    continue
+                raise e
+
+    async def alazy_load(self):
+        """Async version of lazy_load."""
+        for url in self.web_paths:
+            try:
+                await self._safe_process_url(url)
+                loader = FireCrawlLoader(
+                    url=url,
+                    api_key=self.api_key,
+                    api_url=self.api_url,
+                    mode=self.mode,
+                    params=self.params,
+                )
+                async for document in loader.alazy_load():
+                    yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.exception(e, "Error loading %s", url)
+                    continue
+                raise e
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        return verify_ssl_cert(url)
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
+
+    Attributes:
+        web_paths (List[str]): List of URLs to load.
+        verify_ssl (bool): If True, verify SSL certificates.
+        trust_env (bool): If True, use proxy settings from environment variables.
+        requests_per_second (Optional[float]): Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        headless (bool): If True, the browser will run in headless mode.
+        proxy (dict): Proxy override settings for the Playwright session.
+        playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
+    """
+
+    def __init__(
+        self,
+        web_paths: List[str],
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        headless: bool = True,
+        remove_selectors: Optional[List[str]] = None,
+        proxy: Optional[Dict[str, str]] = None,
+        playwright_ws_url: Optional[str] = None,
+    ):
+        """Initialize with additional safety parameters and remote browser support."""
+
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+
+        # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
+        super().__init__(
+            urls=web_paths,
+            continue_on_failure=continue_on_failure,
+            headless=headless if playwright_ws_url is None else False,
+            remove_selectors=remove_selectors,
+            proxy=proxy,
+        )
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+        self.playwright_ws_url = playwright_ws_url
+        self.trust_env = trust_env
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Safely load URLs synchronously with support for remote browser."""
+        from playwright.sync_api import sync_playwright
+
+        with sync_playwright() as p:
+            # Use remote browser if ws_endpoint is provided, otherwise use local browser
+            if self.playwright_ws_url:
+                browser = p.chromium.connect(self.playwright_ws_url)
+            else:
+                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
+
+            for url in self.urls:
+                try:
+                    self._safe_process_url_sync(url)
+                    page = browser.new_page()
+                    response = page.goto(url)
+                    if response is None:
+                        raise ValueError(f"page.goto() returned None for url {url}")
+
+                    text = self.evaluator.evaluate(page, browser, response)
+                    metadata = {"source": url}
+                    yield Document(page_content=text, metadata=metadata)
+                except Exception as e:
+                    if self.continue_on_failure:
+                        log.exception(e, "Error loading %s", url)
+                        continue
+                    raise e
+            browser.close()
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Safely load URLs asynchronously with support for remote browser."""
+        from playwright.async_api import async_playwright
+
+        async with async_playwright() as p:
+            # Use remote browser if ws_endpoint is provided, otherwise use local browser
+            if self.playwright_ws_url:
+                browser = await p.chromium.connect(self.playwright_ws_url)
+            else:
+                browser = await p.chromium.launch(
+                    headless=self.headless, proxy=self.proxy
+                )
+
+            for url in self.urls:
+                try:
+                    await self._safe_process_url(url)
+                    page = await browser.new_page()
+                    response = await page.goto(url)
+                    if response is None:
+                        raise ValueError(f"page.goto() returned None for url {url}")
+
+                    text = await self.evaluator.evaluate_async(page, browser, response)
+                    metadata = {"source": url}
+                    yield Document(page_content=text, metadata=metadata)
+                except Exception as e:
+                    if self.continue_on_failure:
+                        log.exception(e, "Error loading %s", url)
+                        continue
+                    raise e
+            await browser.close()
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        return verify_ssl_cert(url)
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""

+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader
+        Args:
+            trust_env (bool, optional): set to True if using proxy to make web requests, for example
+                using http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
+    def _unpack_fetch_results(
+        self, results: Any, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Unpack fetch results into BeautifulSoup objects."""
+        from bs4 import BeautifulSoup
+
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+                self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
+        return final_results
+
+    async def ascrape_all(
+        self, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Async fetch all urls, then return soups for all results."""
+        results = await self.fetch_all(urls)
+        return self._unpack_fetch_results(results, urls, parser=parser)
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
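Both new loaders share the same client-side throttle: with requests_per_second set, each request waits until at least 1/rate seconds have passed since the previous one. The arithmetic in isolation, as a minimal runnable sketch (the Throttle class is an illustrative stand-in, not part of the change):

from datetime import datetime, timedelta
import time

class Throttle:
    # Minimal version of the loaders' rate limiter: enforce a minimum
    # interval between calls by sleeping off whatever time remains.
    def __init__(self, requests_per_second: float):
        self.requests_per_second = requests_per_second
        self.last_request_time = None

    def wait(self):
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            since_last = datetime.now() - self.last_request_time
            if since_last < min_interval:
                time.sleep((min_interval - since_last).total_seconds())
        self.last_request_time = datetime.now()

throttle = Throttle(2.0)  # 2 requests/second -> at least 0.5s between calls
for _ in range(3):
    throttle.wait()  # second and third calls sleep up to 0.5s each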
@ -65,6 +467,18 @@ class SafeWebBaseLoader(WebBaseLoader):
                 text = soup.get_text(**self.bs_get_text_kwargs)

                 # Build metadata
+                metadata = extract_metadata(soup, path)
+
+                yield Document(page_content=text, metadata=metadata)
+            except Exception as e:
+                # Log the error and continue with the next URL
+                log.exception(e, "Error loading %s", path)
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async lazy load text from the url(s) in web_path."""
+        results = await self.ascrape_all(self.web_paths)
+        for path, soup in zip(self.web_paths, results):
+            text = soup.get_text(**self.bs_get_text_kwargs)
             metadata = {"source": path}
             if title := soup.find("title"):
                 metadata["title"] = title.get_text()
@ -74,24 +488,51 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||||
)
|
)
|
||||||
if html := soup.find("html"):
|
if html := soup.find("html"):
|
||||||
metadata["language"] = html.get("lang", "No language found.")
|
metadata["language"] = html.get("lang", "No language found.")
|
||||||
|
|
||||||
yield Document(page_content=text, metadata=metadata)
|
yield Document(page_content=text, metadata=metadata)
|
||||||
except Exception as e:
|
|
||||||
# Log the error and continue with the next URL
|
async def aload(self) -> list[Document]:
|
||||||
log.error(f"Error loading {path}: {e}")
|
"""Load data into Document objects."""
|
||||||
|
return [document async for document in self.alazy_load()]
|
||||||
|
|
||||||
|
|
||||||
|
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
|
||||||
|
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
|
||||||
|
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
|
||||||
|
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
|
||||||
|
|
||||||
|
|
||||||
 def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
-    # Check if the URL is valid
-    if not validate_url(urls):
-        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return SafeWebBaseLoader(
-        urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
+    # Check if the URLs are valid
+    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
+
+    web_loader_args = {
+        "web_paths": safe_urls,
+        "verify_ssl": verify_ssl,
+        "requests_per_second": requests_per_second,
+        "continue_on_failure": True,
+        "trust_env": trust_env,
+    }
+
+    if PLAYWRIGHT_WS_URI.value:
+        web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
+
+    if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
+        web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
+        web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
+
+    # Create the appropriate WebLoader based on the configuration
+    WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
+    web_loader = WebLoaderClass(**web_loader_args)
+
+    log.debug(
+        "Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
+        web_loader.__class__.__name__,
+        len(safe_urls),
     )
+
+    return web_loader
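Note: because `RAG_WEB_LOADER_ENGINES` is a `defaultdict`, an unrecognized engine value silently falls back to `SafeWebBaseLoader` instead of raising `KeyError`. A minimal sketch of the same dispatch pattern in isolation — the class names and kwargs below are illustrative stand-ins, not the real loaders:

```python
from collections import defaultdict


class BaseLoader:
    """Stand-in for SafeWebBaseLoader; accepts the shared kwargs."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class PlaywrightLoader(BaseLoader):
    """Stand-in for SafePlaywrightURLLoader."""


ENGINES = defaultdict(lambda: BaseLoader)
ENGINES["playwright"] = PlaywrightLoader


def make_loader(engine: str, urls: list) -> BaseLoader:
    # Shared arguments are collected once; the registry picks the class.
    args = {"web_paths": urls, "continue_on_failure": True}
    return ENGINES[engine](**args)


loader = make_loader("no-such-engine", ["https://example.com"])
assert type(loader) is BaseLoader  # unknown engines fall back silently
```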
@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
 import aiohttp
 import aiofiles
 import requests
+import mimetypes

 from fastapi import (
     Depends,
@@ -36,6 +37,7 @@ from open_webui.config import (

 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import (
+    AIOHTTP_CLIENT_TIMEOUT,
     ENV,
     SRC_LOG_LEVELS,
     DEVICE_TYPE,
@@ -52,7 +54,7 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])

-SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


@@ -69,7 +71,7 @@ from pydub.utils import mediainfo
 def is_mp4_audio(file_path):
     """Check if the given file is an MP4 audio file."""
     if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
+        log.error(f"File not found: {file_path}")
         return False

     info = mediainfo(file_path)
@@ -86,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
     """Convert MP4 audio file to WAV format."""
     audio = AudioSegment.from_file(file_path, format="mp4")
     audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
+    log.info(f"Converted {file_path} to {output_path}")


 def set_faster_whisper_model(model: str, auto_update: bool = False):
@@ -138,6 +140,7 @@ class STTConfigForm(BaseModel):
     ENGINE: str
     MODEL: str
     WHISPER_MODEL: str
+    DEEPGRAM_API_KEY: str


 class AudioConfigUpdateForm(BaseModel):
@@ -165,6 +168,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
             "ENGINE": request.app.state.config.STT_ENGINE,
             "MODEL": request.app.state.config.STT_MODEL,
             "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
         },
     }

@@ -190,6 +194,7 @@ async def update_audio_config(
     request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
     request.app.state.config.STT_MODEL = form_data.stt.MODEL
     request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+    request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY

     if request.app.state.config.STT_ENGINE == "":
         request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -214,6 +219,7 @@ async def update_audio_config(
             "ENGINE": request.app.state.config.STT_ENGINE,
             "MODEL": request.app.state.config.STT_MODEL,
             "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+            "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
         },
     }

@@ -260,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         payload["model"] = request.app.state.config.TTS_MODEL

         try:
-            # print(payload)
-            async with aiohttp.ClientSession() as session:
+            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+            async with aiohttp.ClientSession(
+                timeout=timeout, trust_env=True
+            ) as session:
                 async with session.post(
                     url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     json=payload,
@@ -318,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         )

         try:
-            async with aiohttp.ClientSession() as session:
+            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+            async with aiohttp.ClientSession(
+                timeout=timeout, trust_env=True
+            ) as session:
                 async with session.post(
                     f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                     json={
@@ -375,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                <voice name="{language}">{payload["input"]}</voice>
            </speak>"""
-        async with aiohttp.ClientSession() as session:
+        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+        async with aiohttp.ClientSession(
+            timeout=timeout, trust_env=True
+        ) as session:
             async with session.post(
                 f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
                 headers={
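Note: every outbound TTS call now shares one pattern — a total `ClientTimeout` plus `trust_env=True`, so proxy settings (`HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`) are honored. A minimal sketch of the pattern in isolation; the URL and the 30-second budget are illustrative:

```python
import asyncio
import aiohttp


async def fetch_with_budget(url: str) -> int:
    # total= caps the whole request (connect + read); trust_env=True lets
    # aiohttp pick up proxy variables from the environment.
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(url) as resp:
            return resp.status


# asyncio.run(fetch_with_budget("https://example.com"))
```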
@@ -453,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):


 def transcribe(request: Request, file_path):
-    print("transcribe", file_path)
+    log.info(f"transcribe: {file_path}")
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]

@@ -521,6 +535,69 @@ def transcribe(request: Request, file_path):

         raise Exception(detail if detail else "Open WebUI: Server Connection Error")

+    elif request.app.state.config.STT_ENGINE == "deepgram":
+        try:
+            # Determine the MIME type of the file
+            mime, _ = mimetypes.guess_type(file_path)
+            if not mime:
+                mime = "audio/wav"  # fallback to wav if undetectable
+
+            # Read the audio file
+            with open(file_path, "rb") as f:
+                file_data = f.read()
+
+            # Build headers and parameters
+            headers = {
+                "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
+                "Content-Type": mime,
+            }
+
+            # Add model if specified
+            params = {}
+            if request.app.state.config.STT_MODEL:
+                params["model"] = request.app.state.config.STT_MODEL
+
+            # Make request to Deepgram API
+            r = requests.post(
+                "https://api.deepgram.com/v1/listen",
+                headers=headers,
+                params=params,
+                data=file_data,
+            )
+            r.raise_for_status()
+            response_data = r.json()
+
+            # Extract transcript from Deepgram response
+            try:
+                transcript = response_data["results"]["channels"][0]["alternatives"][
+                    0
+                ].get("transcript", "")
+            except (KeyError, IndexError) as e:
+                log.error(f"Malformed response from Deepgram: {str(e)}")
+                raise Exception(
+                    "Failed to parse Deepgram response - unexpected response format"
+                )
+            data = {"text": transcript.strip()}
+
+            # Save transcript
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            return data
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+                except Exception:
+                    detail = f"External: {e}"
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

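Note: two details in the Deepgram branch are worth calling out. The transcript lives at `results.channels[0].alternatives[0].transcript`, and the `except` block reads `r`, which is only bound once `requests.post` has returned — a failure before that point (e.g. opening the file) would surface as `NameError` rather than the intended message. A hardened sketch of the same flow that binds `r` up front; the helper name, API key, and path are placeholders:

```python
import mimetypes
from typing import Optional

import requests


def deepgram_transcribe(file_path: str, api_key: str, model: Optional[str] = None) -> str:
    mime = mimetypes.guess_type(file_path)[0] or "audio/wav"
    r = None  # bound before the try so the except block can inspect it safely
    try:
        with open(file_path, "rb") as f:
            r = requests.post(
                "https://api.deepgram.com/v1/listen",
                headers={"Authorization": f"Token {api_key}", "Content-Type": mime},
                params={"model": model} if model else {},
                data=f.read(),
            )
        r.raise_for_status()
        # Transcript location in Deepgram's pre-recorded transcription response.
        return r.json()["results"]["channels"][0]["alternatives"][0]["transcript"]
    except Exception as e:
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"External: {res['error']}"
            except Exception:
                pass
        raise RuntimeError(detail or str(e)) from e
```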
 def compress_audio(file_path):
     if os.path.getsize(file_path) > MAX_FILE_SIZE:

@@ -602,6 +679,21 @@ def transcription(
 def get_available_models(request: Request) -> list[dict]:
     available_models = []
     if request.app.state.config.TTS_ENGINE == "openai":
+        # Use custom endpoint if not using the official OpenAI API URL
+        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
+            "https://api.openai.com"
+        ):
+            try:
+                response = requests.get(
+                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
+                )
+                response.raise_for_status()
+                data = response.json()
+                available_models = data.get("models", [])
+            except Exception as e:
+                log.error(f"Error fetching models from custom endpoint: {str(e)}")
+                available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+        else:
             available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
     elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
@@ -633,6 +725,29 @@ def get_available_voices(request) -> dict:
     """Returns {voice_id: voice_name} dict"""
     available_voices = {}
     if request.app.state.config.TTS_ENGINE == "openai":
+        # Use custom endpoint if not using the official OpenAI API URL
+        if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
+            "https://api.openai.com"
+        ):
+            try:
+                response = requests.get(
+                    f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
+                )
+                response.raise_for_status()
+                data = response.json()
+                voices_list = data.get("voices", [])
+                available_voices = {voice["id"]: voice["name"] for voice in voices_list}
+            except Exception as e:
+                log.error(f"Error fetching voices from custom endpoint: {str(e)}")
+                available_voices = {
+                    "alloy": "alloy",
+                    "echo": "echo",
+                    "fable": "fable",
+                    "onyx": "onyx",
+                    "nova": "nova",
+                    "shimmer": "shimmer",
+                }
+        else:
             available_voices = {
                 "alloy": "alloy",
                 "echo": "echo",
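Note: both helpers follow the same shape — try the custom `/audio/models` or `/audio/voices` endpoint, and fall back to the static OpenAI catalog on any failure. A condensed sketch of that fetch-with-fallback pattern; the URL and fallback list are illustrative:

```python
import requests


def fetch_or_fallback(url: str, key: str, fallback: list) -> list:
    """Try a custom catalog endpoint; on any error return the static fallback."""
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return response.json().get(key, [])
    except Exception:
        return fallback


models = fetch_or_fallback(
    "https://tts.example.internal/audio/models",  # hypothetical custom endpoint
    "models",
    [{"id": "tts-1"}, {"id": "tts-1-hd"}],
)
```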
@@ -25,16 +25,13 @@ from open_webui.env import (
     WEBUI_AUTH,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
-    WEBUI_SESSION_COOKIE_SAME_SITE,
-    WEBUI_SESSION_COOKIE_SECURE,
+    WEBUI_AUTH_COOKIE_SAME_SITE,
+    WEBUI_AUTH_COOKIE_SECURE,
     SRC_LOG_LEVELS,
 )
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi.responses import RedirectResponse, Response
-from open_webui.config import (
-    OPENID_PROVIDER_URL,
-    ENABLE_OAUTH_SIGNUP,
-)
+from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
 from pydantic import BaseModel
 from open_webui.utils.misc import parse_duration, validate_email_format
 from open_webui.utils.auth import (
@@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
 from typing import Optional, List

 from ssl import CERT_REQUIRED, PROTOCOL_TLS
-from ldap3 import Server, Connection, NONE, Tls
-from ldap3.utils.conv import escape_filter_chars
+
+if ENABLE_LDAP.value:
+    from ldap3 import Server, Connection, NONE, Tls
+    from ldap3.utils.conv import escape_filter_chars

 router = APIRouter()

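Note: gating the `ldap3` imports behind `ENABLE_LDAP` means the dependency is only needed when the feature is actually on. A minimal sketch of the same guard with a friendlier failure mode; the flag name and message are illustrative:

```python
FEATURE_LDAP = True  # stands in for ENABLE_LDAP.value

if FEATURE_LDAP:
    try:
        from ldap3 import Server, Connection  # only imported when LDAP is enabled
    except ImportError as e:
        raise RuntimeError(
            "LDAP auth is enabled but the 'ldap3' package is not installed"
        ) from e
```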
@@ -95,8 +94,8 @@ async def get_session_user(
         value=token,
         expires=datetime_expires_at,
         httponly=True,  # Ensures the cookie is not accessible via JavaScript
-        samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
-        secure=WEBUI_SESSION_COOKIE_SECURE,
+        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+        secure=WEBUI_AUTH_COOKIE_SECURE,
     )

     user_permissions = get_permissions(
@@ -164,7 +163,7 @@ async def update_password(
 ############################
 # LDAP Authentication
 ############################
-@router.post("/ldap", response_model=SigninResponse)
+@router.post("/ldap", response_model=SessionUserResponse)
 async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
     ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
     LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
@@ -231,9 +230,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):

         entry = connection_app.entries[0]
         username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
-        mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
-        if not mail or mail == "" or mail == "[]":
-            raise HTTPException(400, f"User {form_data.user} does not have mail.")
+        email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
+        if not email or email == "" or email == "[]":
+            raise HTTPException(400, f"User {form_data.user} does not have email.")
+        else:
+            email = email.lower()

         cn = str(entry["cn"])
         user_dn = entry.entry_dn

@@ -248,17 +250,22 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         if not connection_user.bind():
             raise HTTPException(400, f"Authentication failed for {form_data.user}")

-        user = Users.get_user_by_email(mail)
+        user = Users.get_user_by_email(email)
         if not user:
             try:
+                user_count = Users.get_num_users()
+
                 role = (
                     "admin"
-                    if Users.get_num_users() == 0
+                    if user_count == 0
                     else request.app.state.config.DEFAULT_USER_ROLE
                 )

                 user = Auths.insert_new_auth(
-                    email=mail, password=str(uuid.uuid4()), name=cn, role=role
+                    email=email,
+                    password=str(uuid.uuid4()),
+                    name=cn,
+                    role=role,
                 )

                 if not user:
@@ -271,7 +278,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             except Exception as err:
                 raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))

-        user = Auths.authenticate_user_by_trusted_header(mail)
+        user = Auths.authenticate_user_by_trusted_header(email)

         if user:
             token = create_token(
@@ -288,6 +295,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
                 httponly=True,  # Ensures the cookie is not accessible via JavaScript
             )

+            user_permissions = get_permissions(
+                user.id, request.app.state.config.USER_PERMISSIONS
+            )
+
             return {
                 "token": token,
                 "token_type": "Bearer",
@@ -296,6 +307,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
                 "name": user.name,
                 "role": user.role,
                 "profile_image_url": user.profile_image_url,
+                "permissions": user_permissions,
             }
         else:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

@@ -378,8 +390,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
         value=token,
         expires=datetime_expires_at,
         httponly=True,  # Ensures the cookie is not accessible via JavaScript
-        samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
-        secure=WEBUI_SESSION_COOKIE_SECURE,
+        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+        secure=WEBUI_AUTH_COOKIE_SECURE,
     )

     user_permissions = get_permissions(
@@ -408,6 +420,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):

 @router.post("/signup", response_model=SessionUserResponse)
 async def signup(request: Request, response: Response, form_data: SignupForm):
     if WEBUI_AUTH:
         if (
             not request.app.state.config.ENABLE_SIGNUP
@@ -422,6 +435,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
         )

+    user_count = Users.get_num_users()
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -432,12 +446,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):

     try:
         role = (
-            "admin"
-            if Users.get_num_users() == 0
-            else request.app.state.config.DEFAULT_USER_ROLE
+            "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
         )

-        if Users.get_num_users() == 0:
+        if user_count == 0:
             # Disable signup after the first user is created
             request.app.state.config.ENABLE_SIGNUP = False

@@ -473,12 +485,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
         value=token,
         expires=datetime_expires_at,
         httponly=True,  # Ensures the cookie is not accessible via JavaScript
-        samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
-        secure=WEBUI_SESSION_COOKIE_SECURE,
+        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+        secure=WEBUI_AUTH_COOKIE_SECURE,
     )

     if request.app.state.config.WEBHOOK_URL:
         post_webhook(
+            request.app.state.WEBUI_NAME,
             request.app.state.config.WEBHOOK_URL,
             WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
             {
@@ -525,7 +538,8 @@ async def signout(request: Request, response: Response):
         if logout_url:
             response.delete_cookie("oauth_id_token")
             return RedirectResponse(
-                url=f"{logout_url}?id_token_hint={oauth_id_token}"
+                headers=response.headers,
+                url=f"{logout_url}?id_token_hint={oauth_id_token}",
             )
         else:
             raise HTTPException(
@@ -591,7 +605,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
     admin_email = request.app.state.config.ADMIN_EMAIL
     admin_name = None

-    print(admin_email, admin_name)
+    log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")

     if admin_email:
         admin = Users.get_user_by_email(admin_email)

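Note: passing `headers=response.headers` to the `RedirectResponse` matters because `delete_cookie` only mutates the injected response object; when a response is returned directly, FastAPI does not merge those headers in, so the `Set-Cookie` deletion would otherwise be dropped. A minimal sketch of the pattern; the endpoint and cookie name are illustrative:

```python
from fastapi import FastAPI, Response
from fastapi.responses import RedirectResponse

app = FastAPI()


@app.get("/logout")
async def logout(response: Response):
    # delete_cookie adds a Set-Cookie header to *this* response object...
    response.delete_cookie("session_token")
    # ...so the redirect must carry those headers, or the deletion is lost.
    return RedirectResponse(
        url="https://idp.example.com/logout", headers=response.headers
    )
```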
@@ -192,7 +192,7 @@ async def get_channel_messages(
 ############################


-async def send_notification(webui_url, channel, message, active_user_ids):
+async def send_notification(name, webui_url, channel, message, active_user_ids):
     users = get_users_with_access("read", channel.access_control)

     for user in users:
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):

             if webhook_url:
                 post_webhook(
+                    name,
                     webhook_url,
                     f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
                     {
@@ -302,6 +303,7 @@ async def post_new_message(

     background_tasks.add_task(
         send_notification,
+        request.app.state.WEBUI_NAME,
         request.app.state.config.WEBUI_URL,
         channel,
         message,
@@ -444,15 +444,21 @@ async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
 ############################


+class CloneForm(BaseModel):
+    title: Optional[str] = None
+
+
 @router.post("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
+async def clone_chat_by_id(
+    form_data: CloneForm, id: str, user=Depends(get_verified_user)
+):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         updated_chat = {
             **chat.chat,
             "originalChatId": chat.id,
             "branchPointMessageId": chat.chat["history"]["currentId"],
-            "title": f"Clone of {chat.title}",
+            "title": form_data.title if form_data.title else f"Clone of {chat.title}",
         }

         chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))

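Note: with `CloneForm` in place, the clone endpoint accepts an optional `title` in the JSON body and falls back to "Clone of …" when it is omitted. A hedged client-side sketch — the base URL, route prefix, chat id, and token are all placeholders:

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder Open WebUI instance
headers = {"Authorization": "Bearer <token>"}  # placeholder API token

# Clone with an explicit title...
requests.post(
    f"{BASE_URL}/api/v1/chats/<chat-id>/clone",
    json={"title": "My fork"},
    headers=headers,
)

# ...or omit it and get the "Clone of <original title>" default.
requests.post(
    f"{BASE_URL}/api/v1/chats/<chat-id>/clone",
    json={},
    headers=headers,
)
```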
@@ -36,6 +36,140 @@ async def export_config(user=Depends(get_admin_user)):
     return get_config()


+############################
+# Direct Connections Config
+############################
+
+
+class DirectConnectionsConfigForm(BaseModel):
+    ENABLE_DIRECT_CONNECTIONS: bool
+
+
+@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+    }
+
+
+@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def set_direct_connections_config(
+    request: Request,
+    form_data: DirectConnectionsConfigForm,
+    user=Depends(get_admin_user),
+):
+    request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
+        form_data.ENABLE_DIRECT_CONNECTIONS
+    )
+    return {
+        "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+    }
+
+
+############################
+# CodeInterpreterConfig
+############################
+
+
+class CodeInterpreterConfigForm(BaseModel):
+    CODE_EXECUTION_ENGINE: str
+    CODE_EXECUTION_JUPYTER_URL: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
+    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
+    CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
+    ENABLE_CODE_INTERPRETER: bool
+    CODE_INTERPRETER_ENGINE: str
+    CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
+    CODE_INTERPRETER_JUPYTER_URL: Optional[str]
+    CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
+    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
+    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
+    CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
+
+
+@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
+async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
+        "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+        "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
+        "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
+        "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
+        "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+        "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+        "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+        "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+        "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+        "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
+    }
+
+
+@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
+async def set_code_execution_config(
+    request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
+    request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
+        form_data.CODE_EXECUTION_JUPYTER_URL
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
+        form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+    )
+    request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
+        form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
+    )
+
+    request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
+    request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
+    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
+        form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
+    )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
+        form_data.CODE_INTERPRETER_JUPYTER_URL
+    )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
+        form_data.CODE_INTERPRETER_JUPYTER_AUTH
+    )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+        form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+    )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+        form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+    )
+    request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
+        form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
+    )
+
+    return {
+        "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
+        "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+        "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
+        "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
+        "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
+        "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+        "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+        "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+        "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+        "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+        "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
+    }
+
+
 ############################
 # SetDefaultModels
 ############################

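Note: both new config pairs follow the same GET/POST round-trip — the POST body is validated against the same Pydantic form the GET returns. A hedged sketch of toggling direct connections; the route prefix and token are placeholders:

```python
import requests

BASE = "http://localhost:8080/api/v1/configs"  # placeholder mount point
headers = {"Authorization": "Bearer <admin-token>"}  # must be an admin user

# Read the current config, flip the flag, and write it back.
current = requests.get(f"{BASE}/direct_connections", headers=headers).json()
current["ENABLE_DIRECT_CONNECTIONS"] = True
updated = requests.post(
    f"{BASE}/direct_connections", json=current, headers=headers
).json()
```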
@@ -3,30 +3,23 @@ import os
 import uuid
 from pathlib import Path
 from typing import Optional
-from pydantic import BaseModel
-import mimetypes
 from urllib.parse import quote

-from open_webui.storage.provider import Storage
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
+from fastapi.responses import FileResponse, StreamingResponse
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
 from open_webui.models.files import (
     FileForm,
     FileModel,
     FileModelResponse,
     Files,
 )
-from open_webui.routers.retrieval import process_file, ProcessFileForm
-from open_webui.config import UPLOAD_DIR
-from open_webui.env import SRC_LOG_LEVELS
-from open_webui.constants import ERROR_MESSAGES
-
-
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
-from fastapi.responses import FileResponse, StreamingResponse
-
-
+from open_webui.routers.retrieval import ProcessFileForm, process_file
+from open_webui.routers.audio import transcribe
+from open_webui.storage.provider import Storage
 from open_webui.utils.auth import get_admin_user, get_verified_user
+from pydantic import BaseModel

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +34,10 @@ router = APIRouter()

 @router.post("/", response_model=FileModelResponse)
 def upload_file(
-    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+    request: Request,
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+    file_metadata: dict = {},
 ):
     log.info(f"file.content_type: {file.content_type}")
     try:
@@ -65,13 +61,29 @@ def upload_file(
                     "name": name,
                     "content_type": file.content_type,
                     "size": len(contents),
+                    "data": file_metadata,
                 },
             }
         ),
     )

     try:
-        process_file(request, ProcessFileForm(file_id=id))
+        if file.content_type in [
+            "audio/mpeg",
+            "audio/wav",
+            "audio/ogg",
+            "audio/x-m4a",
+        ]:
+            file_path = Storage.get_file(file_path)
+            result = transcribe(request, file_path)
+            process_file(
+                request,
+                ProcessFileForm(file_id=id, content=result.get("text", "")),
+                user=user,
+            )
+        else:
+            process_file(request, ProcessFileForm(file_id=id), user=user)

         file_item = Files.get_file_by_id(id=id)
     except Exception as e:
         log.exception(e)
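Note: uploads with an audio MIME type are now transcribed first, and the transcript — not the raw bytes — is what `process_file` indexes. A hedged sketch of the branch in isolation; the helper is illustrative, and `transcribe_fn` stands in for the real transcriber:

```python
AUDIO_TYPES = {"audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"}


def content_for_indexing(content_type: str, file_path: str, transcribe_fn) -> str:
    """Return the text that should be fed to the RAG pipeline."""
    if content_type in AUDIO_TYPES:
        # Audio is transcribed first; the transcript becomes the indexed content.
        return transcribe_fn(file_path).get("text", "")
    return ""  # non-audio files are parsed later by the regular file processor


print(content_for_indexing("audio/wav", "/tmp/clip.wav", lambda p: {"text": "hello"}))
```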
@@ -126,7 +138,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
         Storage.delete_all_files()
     except Exception as e:
         log.exception(e)
-        log.error(f"Error deleting files")
+        log.error("Error deleting files")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -193,7 +205,9 @@ async def update_file_data_content_by_id(
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
             process_file(
-                request, ProcessFileForm(file_id=id, content=form_data.content)
+                request,
+                ProcessFileForm(file_id=id, content=form_data.content),
+                user=user,
             )
             file = Files.get_file_by_id(id=id)
         except Exception as e:
@@ -227,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             filename = file.meta.get("name", file.filename)
             encoded_filename = quote(filename)  # RFC5987 encoding

+            content_type = file.meta.get("content_type")
+            filename = file.meta.get("name", file.filename)
+            encoded_filename = quote(filename)
             headers = {}
-            if file.meta.get("content_type") not in [
-                "application/pdf",
-                "text/plain",
-            ]:
-                headers = {
-                    **headers,
-                    "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
-                }

-            return FileResponse(file_path, headers=headers)
+            if content_type == "application/pdf" or filename.lower().endswith(
+                ".pdf"
+            ):
+                headers["Content-Disposition"] = (
+                    f"inline; filename*=UTF-8''{encoded_filename}"
+                )
+                content_type = "application/pdf"
+            elif content_type != "text/plain":
+                headers["Content-Disposition"] = (
+                    f"attachment; filename*=UTF-8''{encoded_filename}"
+                )
+
+            return FileResponse(file_path, headers=headers, media_type=content_type)

         else:
             raise HTTPException(
@@ -246,7 +267,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             )
     except Exception as e:
         log.exception(e)
-        log.error(f"Error getting file content")
+        log.error("Error getting file content")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -268,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):

     # Check if the file already exists in the cache
     if file_path.is_file():
-        print(f"file_path: {file_path}")
+        log.info(f"file_path: {file_path}")
         return FileResponse(file_path)
     else:
         raise HTTPException(
@@ -277,7 +298,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
         )
     except Exception as e:
         log.exception(e)
-        log.error(f"Error getting file content")
+        log.error("Error getting file content")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -353,7 +374,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
         Storage.delete_file(file.path)
     except Exception as e:
         log.exception(e)
-        log.error(f"Error deleting files")
+        log.error("Error deleting files")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

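Note: the rewritten content endpoint serves PDFs inline (so browsers render them in-page) and everything except plain text as an attachment download. A condensed sketch of that decision; the helper is illustrative, not the actual endpoint:

```python
from typing import Optional


def disposition_for(content_type: str, filename: str) -> Optional[str]:
    """Pick the Content-Disposition for a served file, per the rules above."""
    if content_type == "application/pdf" or filename.lower().endswith(".pdf"):
        return "inline"  # let the browser render PDFs in-page
    if content_type != "text/plain":
        return "attachment"  # force a download for everything else
    return None  # plain text: no disposition header at all


assert disposition_for("application/pdf", "a.pdf") == "inline"
assert disposition_for("image/png", "a.png") == "attachment"
assert disposition_for("text/plain", "a.txt") is None
```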
@@ -1,4 +1,5 @@
 import os
+import logging
 from pathlib import Path
 from typing import Optional

@@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()

@@ -68,7 +74,7 @@ async def create_new_function(

     function = Functions.insert_new_function(user.id, function_type, form_data)

-    function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+    function_cache_dir = CACHE_DIR / "functions" / form_data.id
     function_cache_dir.mkdir(parents=True, exist_ok=True)

     if function:
@@ -79,7 +85,7 @@ async def create_new_function(
             detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Failed to create a new function: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -183,7 +189,7 @@ async def update_function_by_id(
         FUNCTIONS[id] = function_module

         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
-        print(updated)
+        log.debug(updated)

         function = Functions.update_function_by_id(id, updated)

@@ -299,7 +305,7 @@ async def update_function_valves_by_id(
         Functions.update_function_valves_by_id(id, valves.model_dump())
         return valves.model_dump()
     except Exception as e:
-        print(e)
+        log.exception(f"Error updating function values by id {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
         )
         return user_valves.model_dump()
     except Exception as e:
-        print(e)
+        log.exception(f"Error updating function user valves by id {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
backend/open_webui/routers/groups.py (16 changes; mode changed: Normal file → Executable file)
@@ -1,7 +1,7 @@
 import os
 from pathlib import Path
 from typing import Optional
+import logging

 from open_webui.models.users import Users
 from open_webui.models.groups import (
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status

 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()


@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):


 @router.post("/create", response_model=Optional[GroupResponse])
-async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
+async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
     try:
         group = Groups.insert_new_group(user.id, form_data)
         if group:
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
             detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Error creating a new group: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -94,7 +100,7 @@ async def update_group_by_id(
             detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Error updating group {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
             detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Error deleting group {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),

@@ -1,37 +1,31 @@
 import asyncio
 import base64
+import io
 import json
 import logging
 import mimetypes
 import re
-import uuid
 from pathlib import Path
 from typing import Optional

 import requests
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile

-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel

 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
+from open_webui.routers.files import upload_file
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
     ComfyUIWorkflow,
     comfyui_generate_image,
 )
+from pydantic import BaseModel

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])

-IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
+IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)


@@ -61,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
             "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
             "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
+        "gemini": {
+            "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+            "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+        },
     }


@@ -84,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
     COMFYUI_WORKFLOW_NODES: list[dict]


+class GeminiConfigForm(BaseModel):
+    GEMINI_API_BASE_URL: str
+    GEMINI_API_KEY: str
+
+
 class ConfigForm(BaseModel):
     enabled: bool
     engine: str
@@ -91,6 +94,7 @@ class ConfigForm(BaseModel):
     openai: OpenAIConfigForm
     automatic1111: Automatic1111ConfigForm
     comfyui: ComfyUIConfigForm
+    gemini: GeminiConfigForm


 @router.post("/config/update")
@@ -109,6 +113,11 @@ async def update_config(
     )
     request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

+    request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
+        form_data.gemini.GEMINI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
+
     request.app.state.config.AUTOMATIC1111_BASE_URL = (
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
     )
@@ -135,6 +144,8 @@ async def update_config(
     request.app.state.config.COMFYUI_BASE_URL = (
         form_data.comfyui.COMFYUI_BASE_URL.strip("/")
     )
+    request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY

     request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
     request.app.state.config.COMFYUI_WORKFLOW_NODES = (
         form_data.comfyui.COMFYUI_WORKFLOW_NODES
@@ -161,6 +172,10 @@ async def update_config(
             "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
             "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
+        "gemini": {
+            "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+            "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+        },
     }

@ -190,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
||||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||||
|
|
||||||
|
headers = None
|
||||||
|
if request.app.state.config.COMFYUI_API_KEY:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return True
|
return True
|
||||||
|
|
@ -230,6 +253,12 @@ def get_image_model(request):
|
||||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
else "dall-e-2"
|
else "dall-e-2"
|
||||||
)
|
)
|
||||||
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||||
|
return (
|
||||||
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
|
else "imagen-3.0-generate-002"
|
||||||
|
)
|
||||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||||
return (
|
return (
|
||||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
|
|
@ -271,7 +300,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
||||||
async def update_image_config(
|
async def update_image_config(
|
||||||
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
set_image_model(request, form_data.MODEL)
|
set_image_model(request, form_data.MODEL)
|
||||||
|
|
||||||
pattern = r"^\d+x\d+$"
|
pattern = r"^\d+x\d+$"
|
||||||
|
|
@ -306,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||||
]
|
]
|
||||||
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||||
|
return [
|
||||||
|
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||||
|
]
|
||||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||||
# TODO - get models from comfyui
|
# TODO - get models from comfyui
|
||||||
headers = {
|
headers = {
|
||||||
|
|
@ -329,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||||
if model_node_id:
|
if model_node_id:
|
||||||
model_list_key = None
|
model_list_key = None
|
||||||
|
|
||||||
print(workflow[model_node_id]["class_type"])
|
log.info(workflow[model_node_id]["class_type"])
|
||||||
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
||||||
"required"
|
"required"
|
||||||
]:
|
]:
|
||||||
|
|
@@ -383,40 +415,22 @@ class GenerateImageForm(BaseModel):
     negative_prompt: Optional[str] = None


-def save_b64_image(b64_str):
+def load_b64_image_data(b64_str):
     try:
-        image_id = str(uuid.uuid4())
-
         if "," in b64_str:
             header, encoded = b64_str.split(",", 1)
             mime_type = header.split(";")[0]
-
             img_data = base64.b64decode(encoded)
-            image_format = mimetypes.guess_extension(mime_type)
-
-            image_filename = f"{image_id}{image_format}"
-            file_path = IMAGE_CACHE_DIR / f"{image_filename}"
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
         else:
-            image_filename = f"{image_id}.png"
-            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
-
+            mime_type = "image/png"
             img_data = base64.b64decode(b64_str)
-
-            # Write the image data to a file
-            with open(file_path, "wb") as f:
-                f.write(img_data)
-            return image_filename
-
+        return img_data, mime_type
     except Exception as e:
-        log.exception(f"Error saving image: {e}")
+        log.exception(f"Error loading image data: {e}")
         return None


-def save_url_image(url, headers=None):
-    image_id = str(uuid.uuid4())
+def load_url_image_data(url, headers=None):
     try:
         if headers:
             r = requests.get(url, headers=headers)

@@ -426,18 +440,7 @@ def save_url_image(url, headers=None):
         r.raise_for_status()
         if r.headers["content-type"].split("/")[0] == "image":
             mime_type = r.headers["content-type"]
-            image_format = mimetypes.guess_extension(mime_type)
-
-            if not image_format:
-                raise ValueError("Could not determine image type from MIME type")
-
-            image_filename = f"{image_id}{image_format}"
-
-            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
-            with open(file_path, "wb") as image_file:
-                for chunk in r.iter_content(chunk_size=8192):
-                    image_file.write(chunk)
-            return image_filename
+            return r.content, mime_type
         else:
             log.error("Url does not point to an image.")
             return None
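The two helpers above were renamed from `save_*` to `load_*` because they no longer persist anything: each now returns an `(image_data, mime_type)` tuple and leaves storage to `upload_image` below. A standalone sketch of the data-URL split that `load_b64_image_data` performs (the payload here is hypothetical and truncated):

```python
import base64

b64_str = "data:image/png;base64,iVBORw0KGgo="  # hypothetical, truncated payload
header, encoded = b64_str.split(",", 1)  # "data:image/png;base64" / "iVBORw0KGgo="
mime_type = header.split(";")[0]         # "data:image/png"
img_data = base64.b64decode(encoded)     # raw image bytes
```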
@@ -447,6 +450,20 @@ def save_url_image(url, headers=None):
         return None


+def upload_image(request, image_metadata, image_data, content_type, user):
+    image_format = mimetypes.guess_extension(content_type)
+    file = UploadFile(
+        file=io.BytesIO(image_data),
+        filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
+        headers={
+            "content-type": content_type,
+        },
+    )
+    file_item = upload_file(request, file, user, file_metadata=image_metadata)
+    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+    return url
+
+
 @router.post("/generations")
 async def image_generations(
     request: Request,
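With `upload_image` in place, generated images no longer land in the shared `/cache/image/generations/` directory; they go through the files router and are served by file id. The URL comes from Starlette route reversal, which works like this minimal sketch (the route path and name here are illustrative assumptions, not copied from the files router):

```python
from fastapi import FastAPI, Response

app = FastAPI()

@app.get("/files/{id}/content", name="get_file_content_by_id")
def get_file_content_by_id(id: str) -> Response:
    return Response(b"...", media_type="image/png")

# Reverse the route name into a concrete path for a given file id.
print(app.url_path_for("get_file_content_by_id", id="abc123"))  # /files/abc123/content
```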
@@ -500,12 +517,49 @@ async def image_generations(
             images = []

             for image in res["data"]:
-                image_filename = save_b64_image(image["b64_json"])
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump(data, f)
+                if "url" in image:
+                    image_data, content_type = load_url_image_data(
+                        image["url"], headers
+                    )
+                else:
+                    image_data, content_type = load_b64_image_data(image["b64_json"])
+
+                url = upload_image(request, data, image_data, content_type, user)
+                images.append({"url": url})
+            return images
+
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+            headers = {}
+            headers["Content-Type"] = "application/json"
+            headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
+
+            model = get_image_model(request)
+            data = {
+                "instances": {"prompt": form_data.prompt},
+                "parameters": {
+                    "sampleCount": form_data.n,
+                    "outputOptions": {"mimeType": "image/png"},
+                },
+            }
+
+            # Use asyncio.to_thread for the requests.post call
+            r = await asyncio.to_thread(
+                requests.post,
+                url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
+                json=data,
+                headers=headers,
+            )
+
+            r.raise_for_status()
+            res = r.json()
+
+            images = []
+            for image in res["predictions"]:
+                image_data, content_type = load_b64_image_data(
+                    image["bytesBase64Encoded"]
+                )
+                url = upload_image(request, data, image_data, content_type, user)
+                images.append({"url": url})

             return images
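The new Gemini branch calls the Imagen `:predict` endpoint with the request shape shown in the diff. Isolated as a sketch (the base URL default and the API key are placeholder assumptions):

```python
import requests

GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"  # assumed default
model = "imagen-3.0-generate-002"

r = requests.post(
    url=f"{GEMINI_API_BASE_URL}/models/{model}:predict",
    headers={"Content-Type": "application/json", "x-goog-api-key": "YOUR_KEY"},  # placeholder key
    json={
        "instances": {"prompt": "a watercolor fox"},
        "parameters": {"sampleCount": 1, "outputOptions": {"mimeType": "image/png"}},
    },
)
r.raise_for_status()
for p in r.json()["predictions"]:
    png_b64 = p["bytesBase64Encoded"]  # base64-encoded image payload per prediction
```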
@@ -552,14 +606,15 @@ async def image_generations(
                         "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
                     }

-                image_filename = save_url_image(image["url"], headers)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump(form_data.model_dump(exclude_none=True), f)
-
-            log.debug(f"images: {images}")
+                image_data, content_type = load_url_image_data(image["url"], headers)
+                url = upload_image(
+                    request,
+                    form_data.model_dump(exclude_none=True),
+                    image_data,
+                    content_type,
+                    user,
+                )
+                images.append({"url": url})
             return images
         elif (
             request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"

@@ -604,13 +659,15 @@ async def image_generations(
             images = []

             for image in res["images"]:
-                image_filename = save_b64_image(image)
-                images.append({"url": f"/cache/image/generations/{image_filename}"})
-                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
-                with open(file_body_path, "w") as f:
-                    json.dump({**data, "info": res["info"]}, f)
+                image_data, content_type = load_b64_image_data(image)
+                url = upload_image(
+                    request,
+                    {**data, "info": res["info"]},
+                    image_data,
+                    content_type,
+                    user,
+                )
+                images.append({"url": url})
             return images
     except Exception as e:
         error = e
backend/open_webui/routers/knowledge.py

@@ -264,7 +264,11 @@ def add_file_to_knowledge_by_id(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

@@ -285,7 +289,9 @@ def add_file_to_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user,
         )
     except Exception as e:
         log.debug(e)

@@ -342,7 +348,12 @@ def update_file_from_knowledge_by_id(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
+
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

@@ -363,7 +374,9 @@ def update_file_from_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user,
         )
     except Exception as e:
         raise HTTPException(

@@ -406,7 +419,11 @@ def remove_file_from_knowledge_by_id(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

@@ -429,10 +446,6 @@ def remove_file_from_knowledge_by_id(
     if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
         VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)

-    # Delete physical file
-    if file.path:
-        Storage.delete_file(file.path)
-
     # Delete file from database
     Files.delete_file_by_id(form_data.file_id)

@@ -484,7 +497,11 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

@@ -543,7 +560,11 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

@@ -582,14 +603,18 @@ def add_files_to_knowledge_batch(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if knowledge.user_id != user.id and user.role != "admin":
+    if (
+        knowledge.user_id != user.id
+        and not has_access(user.id, "write", knowledge.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )

     # Get files content
-    print(f"files/batch/add - {len(form_data)} files")
+    log.info(f"files/batch/add - {len(form_data)} files")
     files: List[FileModel] = []
     for form in form_data:
         file = Files.get_file_by_id(form.file_id)
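Every knowledge route now applies the same widened write guard: the owner, anyone granted write access via `access_control`, or an admin may proceed (the models route below uses the same idea with the clauses reordered). Expressed once as a sketch, a helper like this is not part of the PR and is shown only to document the recurring condition (`has_access` is imported from `open_webui.utils.access_control` in this codebase):

```python
def can_write(user, knowledge) -> bool:
    # Owner, explicitly granted writer, or admin.
    return (
        knowledge.user_id == user.id
        or has_access(user.id, "write", knowledge.access_control)
        or user.role == "admin"
    )
```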
backend/open_webui/routers/memories.py

@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {"created_at": memory.created_at},
             }
         ],

@@ -82,7 +82,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
         limit=form_data.k,
     )

@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,

@@ -160,7 +160,9 @@ async def update_memory_by_id(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(
+                    memory.content, user
+                ),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
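`EMBEDDING_FUNCTION` now takes the requesting user as a second argument, so embedding backends that forward user info (see the `X-OpenWebUI-User-*` headers added in the Ollama and OpenAI routers below) can attribute each call. A sketch of the expected callable shape; the factory and client API here are illustrative assumptions, not from the PR:

```python
def make_embedding_function(client):
    # client is whatever engine-specific handle the app configured at startup.
    def embedding_function(text: str, user=None) -> list[float]:
        # user is threaded through purely for attribution/header forwarding.
        return client.embed(text, user=user)  # hypothetical client API
    return embedding_function
```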
backend/open_webui/routers/models.py

@@ -183,7 +183,11 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if model.user_id != user.id and user.role != "admin":
+    if (
+        user.role != "admin"
+        and model.user_id != user.id
+        and not has_access(user.id, "write", model.access_control)
+    ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
backend/open_webui/routers/ollama.py

@@ -11,11 +11,14 @@ import re
 import time
 from typing import Optional, Union
 from urllib.parse import urlparse

 import aiohttp
 from aiocache import cached

 import requests
+from open_webui.models.users import UserModel

+from open_webui.env import (
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)
 from fastapi import (
     Depends,

@@ -28,7 +31,7 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, validator
 from starlette.background import BackgroundTask

@@ -52,7 +55,7 @@ from open_webui.env import (
     ENV,
     SRC_LOG_LEVELS,
     AIOHTTP_CLIENT_TIMEOUT,
-    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
+    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
     BYPASS_MODEL_ACCESS_CONTROL,
 )
 from open_webui.constants import ERROR_MESSAGES

@@ -68,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 ##########################################


-async def send_get_request(url, key=None):
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+async def send_get_request(url, key=None, user: UserModel = None):
+    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.get(
-                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             ) as response:
                 return await response.json()
     except Exception as e:
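This user-info header block reappears nearly verbatim in every Ollama (and OpenAI) request below. For reference, the pattern collapses to a small helper; this is only a sketch of the recurring idiom, the PR itself inlines the dict at each call site:

```python
def user_info_headers(user) -> dict:
    # Empty unless forwarding is enabled and a user is attached to the request.
    if not (ENABLE_FORWARD_USER_INFO_HEADERS and user):
        return {}
    return {
        "X-OpenWebUI-User-Name": user.name,
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }
```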
@@ -98,6 +115,7 @@ async def send_post_request(
     stream: bool = True,
     key: Optional[str] = None,
     content_type: Optional[str] = None,
+    user: UserModel = None,
 ):
     r = None

@@ -112,6 +130,16 @@ async def send_post_request(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
         )
         r.raise_for_status()

@@ -188,12 +216,24 @@ async def verify_connection(
     key = form_data.key

     async with aiohttp.ClientSession(
-        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     ) as session:
         try:
             async with session.get(
                 f"{url}/api/version",
-                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             ) as r:
                 if r.status != 200:
                     detail = f"HTTP Error: {r.status}"

@@ -256,7 +296,7 @@ async def update_config(
 @cached(ttl=3)
-async def get_all_models(request: Request):
+async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
     if request.app.state.config.ENABLE_OLLAMA_API:
         request_tasks = []

@@ -264,7 +304,7 @@ async def get_all_models(request: Request):
             if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                 url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
             ):
-                request_tasks.append(send_get_request(f"{url}/api/tags"))
+                request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
             else:
                 api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                     str(idx),

@@ -277,7 +317,9 @@ async def get_all_models(request: Request):
                 key = api_config.get("key", None)

                 if enable:
-                    request_tasks.append(send_get_request(f"{url}/api/tags", key))
+                    request_tasks.append(
+                        send_get_request(f"{url}/api/tags", key, user=user)
+                    )
                 else:
                     request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

@@ -362,7 +404,7 @@ async def get_ollama_tags(
     models = []

     if url_idx is None:
-        models = await get_all_models(request)
+        models = await get_all_models(request, user=user)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
         key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

@@ -372,7 +414,19 @@ async def get_ollama_tags(
             r = requests.request(
                 method="GET",
                 url=f"{url}/api/tags",
-                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             )
             r.raise_for_status()

@@ -395,7 +449,7 @@ async def get_ollama_tags(
         )

     if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-        models["models"] = get_filtered_models(models, user)
+        models["models"] = await get_filtered_models(models, user)

     return models

@@ -479,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
                     url, {}
                 ),  # Legacy support
             ).get("key", None),
+            user=user,
         )
         for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
     ]

@@ -511,6 +566,7 @@ async def pull_model(
         url=f"{url}/api/pull",
         payload=json.dumps(payload),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )

@@ -529,7 +585,7 @@ async def push_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         if form_data.name in models:

@@ -547,6 +603,7 @@ async def push_model(
         url=f"{url}/api/push",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )

@@ -573,6 +630,7 @@ async def create_model(
         url=f"{url}/api/create",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )

@@ -590,7 +648,7 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         if form_data.source in models:

@@ -611,6 +669,16 @@ async def copy_model(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )

@@ -645,7 +713,7 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         if form_data.name in models:

@@ -667,6 +735,16 @@ async def delete_model(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
         )
         r.raise_for_status()

@@ -695,7 +773,7 @@ async def delete_model(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
-    await get_all_models(request)
+    await get_all_models(request, user=user)
     models = request.app.state.OLLAMA_MODELS

     if form_data.name not in models:

@@ -716,6 +794,16 @@ async def show_model_info(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )

@@ -759,7 +847,7 @@ async def embed(
     log.info(f"generate_ollama_batch_embeddings {form_data}")

     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         model = form_data.model

@@ -785,6 +873,16 @@ async def embed(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )

@@ -828,7 +926,7 @@ async def embeddings(
     log.info(f"generate_ollama_embeddings {form_data}")

     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         model = form_data.model

@@ -854,6 +952,16 @@ async def embeddings(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )

@@ -903,7 +1011,7 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS

         model = form_data.model
@@ -933,23 +1041,39 @@ async def generate_completion(
         url=f"{url}/api/generate",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )


 class ChatMessage(BaseModel):
     role: str
-    content: str
+    content: Optional[str] = None
+    tool_calls: Optional[list[dict]] = None
     images: Optional[list[str]] = None

+    @validator("content", pre=True)
+    @classmethod
+    def check_at_least_one_field(cls, field_value, values, **kwargs):
+        # Raise an error if both 'content' and 'tool_calls' are None
+        if field_value is None and (
+            "tool_calls" not in values or values["tool_calls"] is None
+        ):
+            raise ValueError(
+                "At least one of 'content' or 'tool_calls' must be provided"
+            )
+
+        return field_value
+

 class GenerateChatCompletionForm(BaseModel):
     model: str
     messages: list[ChatMessage]
-    format: Optional[dict] = None
+    format: Optional[Union[dict, str]] = None
     options: Optional[dict] = None
     template: Optional[str] = None
     stream: Optional[bool] = True
     keep_alive: Optional[Union[int, str]] = None
+    tools: Optional[list[dict]] = None


 async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
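The relaxed `ChatMessage` model is meant to accept tool-call-only assistant messages. Intended contract as a sketch:

```python
ChatMessage(role="user", content="hi")                         # valid
ChatMessage(role="assistant", tool_calls=[{"function": {}}])   # intended to be valid
# ChatMessage(role="assistant")  # should raise: neither content nor tool_calls
```

One review note: with a pydantic v1-style `@validator(..., pre=True)`, `values` only contains fields declared *before* the validated field, and `tool_calls` is declared after `content` here, so the `"tool_calls" not in values` branch deserves a second look.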
@@ -977,6 +1101,7 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

+    metadata = form_data.pop("metadata", None)
     try:
         form_data = GenerateChatCompletionForm(**form_data)
     except Exception as e:

@@ -1006,7 +1131,7 @@ async def generate_chat_completion(
             payload["options"] = apply_model_params_to_body_ollama(
                 params, payload["options"]
             )
-        payload = apply_model_system_prompt_to_body(params, payload, user)
+        payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

     # Check if user has access to the model
     if not bypass_filter and user.role == "user":

@@ -1046,6 +1171,7 @@ async def generate_chat_completion(
         stream=form_data.stream,
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
         content_type="application/x-ndjson",
+        user=user,
     )

@@ -1148,6 +1274,7 @@ async def generate_openai_completion(
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )

@@ -1159,6 +1286,8 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
+    metadata = form_data.pop("metadata", None)
+
     try:
         completion_form = OpenAIChatCompletionForm(**form_data)
     except Exception as e:

@@ -1185,7 +1314,7 @@ async def generate_openai_chat_completion(
     if params:
         payload = apply_model_params_to_body_openai(params, payload)
-        payload = apply_model_system_prompt_to_body(params, payload, user)
+        payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

     # Check if user has access to the model
     if user.role == "user":

@@ -1224,6 +1353,7 @@ async def generate_openai_chat_completion(
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )

@@ -1237,7 +1367,7 @@ async def get_openai_models(
     models = []
     if url_idx is None:
-        model_list = await get_all_models(request)
+        model_list = await get_all_models(request, user=user)
         models = [
             {
                 "id": model["model"],

@@ -1405,9 +1535,10 @@ async def download_model(
     return None


+# TODO: Progress bar does not reflect size & duration of upload.
 @router.post("/models/upload")
 @router.post("/models/upload/{url_idx}")
-def upload_model(
+async def upload_model(
     request: Request,
     file: UploadFile = File(...),
     url_idx: Optional[int] = None,
@@ -1416,60 +1547,86 @@ def upload_model(
     if url_idx is None:
         url_idx = 0
     ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+    os.makedirs(UPLOAD_DIR, exist_ok=True)

-    file_path = f"{UPLOAD_DIR}/{file.filename}"
+    # --- P1: save file locally ---
+    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
+    with open(file_path, "wb") as out_f:
+        while True:
+            chunk = file.file.read(chunk_size)
+            # log.info(f"Chunk: {str(chunk)}") # DEBUG
+            if not chunk:
+                break
+            out_f.write(chunk)

-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
+    async def file_process_stream():
         nonlocal ollama_url
         total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
+        log.info(f"Total Model Size: {str(total_size)}")  # DEBUG
+
+        # --- P2: SSE progress + calculate sha256 hash ---
+        file_hash = calculate_sha256(file_path, chunk_size)
+        log.info(f"Model Hash: {str(file_hash)}")  # DEBUG
         try:
             with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
+                bytes_read = 0
+                while chunk := f.read(chunk_size):
+                    bytes_read += len(chunk)
+                    progress = round(bytes_read / total_size * 100, 2)
+                    data_msg = {
                         "progress": progress,
                         "total": total_size,
-                        "completed": total,
+                        "completed": bytes_read,
                     }
-                    yield f"data: {json.dumps(res)}\n\n"
+                    yield f"data: {json.dumps(data_msg)}\n\n"

-                    if done:
-                        f.seek(0)
-                        hashed = calculate_sha256(f)
-                        f.seek(0)
-
-                        url = f"{ollama_url}/api/blobs/sha256:{hashed}"
-                        response = requests.post(url, data=f)
-
-                        if response.ok:
-                            res = {
-                                "done": done,
-                                "blob": f"sha256:{hashed}",
-                                "name": file.filename,
-                            }
-                            os.remove(file_path)
-                            yield f"data: {json.dumps(res)}\n\n"
-                        else:
-                            raise Exception(
-                                "Ollama: Could not create blob, Please try again."
-                            )
+            # --- P3: Upload to ollama /api/blobs ---
+            with open(file_path, "rb") as f:
+                url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
+                response = requests.post(url, data=f)
+
+                if response.ok:
+                    log.info(f"Uploaded to /api/blobs")  # DEBUG
+                    # Remove local file
+                    os.remove(file_path)
+
+                    # Create model in ollama
+                    model_name, ext = os.path.splitext(file.filename)
+                    log.info(f"Created Model: {model_name}")  # DEBUG
+
+                    create_payload = {
+                        "model": model_name,
+                        # Reference the file by its original name => the uploaded blob's digest
+                        "files": {file.filename: f"sha256:{file_hash}"},
+                    }
+                    log.info(f"Model Payload: {create_payload}")  # DEBUG
+
+                    # Call ollama /api/create
+                    # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
+                    create_resp = requests.post(
+                        url=f"{ollama_url}/api/create",
+                        headers={"Content-Type": "application/json"},
+                        data=json.dumps(create_payload),
+                    )
+
+                    if create_resp.ok:
+                        log.info(f"API SUCCESS!")  # DEBUG
+                        done_msg = {
+                            "done": True,
+                            "blob": f"sha256:{file_hash}",
+                            "name": file.filename,
+                            "model_created": model_name,
+                        }
+                        yield f"data: {json.dumps(done_msg)}\n\n"
+                    else:
+                        raise Exception(
+                            f"Failed to create model in Ollama. {create_resp.text}"
+                        )
+
+                else:
+                    raise Exception("Ollama: Could not create blob, Please try again.")

         except Exception as e:
             res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"
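The rewritten `upload_model` splits the flow into save, hash/progress, blob upload, and model creation. The Ollama side of that flow, stripped of the SSE plumbing, looks like this sketch (endpoints are the ones the diff itself links in the Ollama API docs; the file and model names are hypothetical):

```python
import hashlib
import json
import requests

OLLAMA_URL = "http://localhost:11434"  # assumed local instance
path = "my-model.gguf"                 # hypothetical upload

# 1) Hash the file; the digest names the blob on the Ollama side.
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(2 * 1024 * 1024), b""):
        h.update(chunk)
digest = h.hexdigest()

# 2) Push the raw bytes to /api/blobs/sha256:<digest>.
with open(path, "rb") as f:
    requests.post(f"{OLLAMA_URL}/api/blobs/sha256:{digest}", data=f).raise_for_status()

# 3) Register a model that references the blob by filename -> digest.
requests.post(
    f"{OLLAMA_URL}/api/create",
    data=json.dumps({"model": "my-model", "files": {path: f"sha256:{digest}"}}),
    headers={"Content-Type": "application/json"},
).raise_for_status()
```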
backend/open_webui/routers/openai.py

@@ -22,10 +22,11 @@ from open_webui.config import (
 )
 from open_webui.env import (
     AIOHTTP_CLIENT_TIMEOUT,
-    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
+    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
     ENABLE_FORWARD_USER_INFO_HEADERS,
     BYPASS_MODEL_ACCESS_CONTROL,
 )
+from open_webui.models.users import UserModel

 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENV, SRC_LOG_LEVELS

@@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
 ##########################################


-async def send_get_request(url, key=None):
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+async def send_get_request(url, key=None, user: UserModel = None):
+    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.get(
-                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+                url,
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             ) as response:
                 return await response.json()
     except Exception as e:
@@ -75,18 +89,24 @@ async def cleanup_response(
         await session.close()


-def openai_o1_handler(payload):
+def openai_o1_o3_handler(payload):
     """
-    Handle O1 specific parameters
+    Handle o1, o3 specific parameters
     """
     if "max_tokens" in payload:
         # Remove "max_tokens" from the payload
         payload["max_completion_tokens"] = payload["max_tokens"]
         del payload["max_tokens"]

-    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
+    # Fix: o1 and o3 do not support the "system" role directly.
+    # For older models like "o1-mini" or "o1-preview", use role "user".
+    # For newer o1/o3 models, replace "system" with "developer".
     if payload["messages"][0]["role"] == "system":
-        payload["messages"][0]["role"] = "user"
+        model_lower = payload["model"].lower()
+        if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
+            payload["messages"][0]["role"] = "user"
+        else:
+            payload["messages"][0]["role"] = "developer"

     return payload
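Effect of the renamed handler on a typical payload, as a quick sketch (the model name is a hypothetical example):

```python
payload = {
    "model": "o3-mini",
    "messages": [{"role": "system", "content": "Be terse."}],
    "max_tokens": 100,
}
payload = openai_o1_o3_handler(payload)
# "max_tokens" is renamed to "max_completion_tokens" (still 100), and the
# leading "system" message becomes role "developer" ("user" for o1-mini/o1-preview).
```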
@ -172,7 +192,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
name = hashlib.sha256(body).hexdigest()
|
name = hashlib.sha256(body).hexdigest()
|
||||||
|
|
||||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||||
|
|
@@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)


-async def get_all_models_responses(request: Request) -> list:
+async def get_all_models_responses(request: Request, user: UserModel) -> list:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return []

@@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
         ):
             request_tasks.append(
                 send_get_request(
-                    f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
+                    f"{url}/models",
+                    request.app.state.config.OPENAI_API_KEYS[idx],
+                    user=user,
                 )
             )
         else:
@@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
                 send_get_request(
                     f"{url}/models",
                     request.app.state.config.OPENAI_API_KEYS[idx],
+                    user=user,
                 )
             )
         else:
@@ -352,13 +375,13 @@ async def get_filtered_models(models, user):


 @cached(ttl=3)
-async def get_all_models(request: Request) -> dict[str, list]:
+async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
     log.info("get_all_models()")

     if not request.app.state.config.ENABLE_OPENAI_API:
         return {"data": []}

-    responses = await get_all_models_responses(request)
+    responses = await get_all_models_responses(request, user=user)

     def extract_data(response):
         if response and "data" in response:
@@ -418,16 +441,14 @@ async def get_models(
     }

     if url_idx is None:
-        models = await get_all_models(request)
+        models = await get_all_models(request, user=user)
     else:
         url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = request.app.state.config.OPENAI_API_KEYS[url_idx]

         r = None
         async with aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(
-                total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
-            )
+            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
         ) as session:
             try:
                 async with session.get(
@@ -489,7 +510,7 @@ async def get_models(
             raise HTTPException(status_code=500, detail=error_detail)

     if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-        models["data"] = get_filtered_models(models, user)
+        models["data"] = await get_filtered_models(models, user)

     return models

@@ -507,7 +528,7 @@ async def verify_connection(
     key = form_data.key

     async with aiohttp.ClientSession(
-        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     ) as session:
         try:
             async with session.get(
@@ -515,6 +536,16 @@ async def verify_connection(
                 headers={
                     "Authorization": f"Bearer {key}",
                     "Content-Type": "application/json",
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS
+                        else {}
+                    ),
                 },
             ) as r:
                 if r.status != 200:
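The `**( ... if ... else {})` splat merges the optional user-info headers into the dict literal only when forwarding is enabled. The same pattern in isolation (the user object and flag here are stand-ins for the real config):

```python
from types import SimpleNamespace

ENABLE_FORWARD_USER_INFO_HEADERS = True  # stand-in for the config flag
user = SimpleNamespace(name="Ada", id="u1", email="ada@example.com", role="admin")

headers = {
    "Authorization": "Bearer sk-...",
    "Content-Type": "application/json",
    # Conditionally splice extra headers in without mutating the dict afterwards.
    **(
        {
            "X-OpenWebUI-User-Name": user.name,
            "X-OpenWebUI-User-Id": user.id,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS
        else {}
    ),
}
print(sorted(headers))
```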
@@ -551,9 +582,9 @@ async def generate_chat_completion(
         bypass_filter = True

     idx = 0

     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    metadata = payload.pop("metadata", None)

     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
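`payload.pop("metadata", None)` replaces the `if`/`del` pair: it removes the key when present, returns `None` otherwise, and keeps the value around for the later `apply_model_system_prompt_to_body` call. In isolation:

```python
payload = {"model": "gpt-4o", "metadata": {"chat_id": "abc"}}
metadata = payload.pop("metadata", None)  # one call: remove + capture, no KeyError
assert "metadata" not in payload and metadata == {"chat_id": "abc"}

metadata = payload.pop("metadata", None)  # a missing key just yields the default
assert metadata is None
```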
@@ -566,7 +597,7 @@ async def generate_chat_completion(

         params = model_info.params.model_dump()
         payload = apply_model_params_to_body_openai(params, payload)
-        payload = apply_model_system_prompt_to_body(params, payload, user)
+        payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

         # Check if user has access to the model
         if not bypass_filter and user.role == "user":
@@ -587,7 +618,7 @@ async def generate_chat_completion(
             detail="Model not found",
         )

-    await get_all_models(request)
+    await get_all_models(request, user=user)
     model = request.app.state.OPENAI_MODELS.get(model_id)
     if model:
         idx = model["urlIdx"]
@@ -621,10 +652,10 @@ async def generate_chat_completion(
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]

-    # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
-    is_o1 = payload["model"].lower().startswith("o1-")
-    if is_o1:
-        payload = openai_o1_handler(payload)
+    # Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
+    is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
+    if is_o1_o3:
+        payload = openai_o1_o3_handler(payload)
     elif "api.openai.com" not in url:
         # Remove "max_completion_tokens" from the payload for backward compatibility
         if "max_completion_tokens" in payload:
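`str.startswith` accepts a tuple of prefixes, which is how a single check now covers both model families. Note the asymmetry in the tuple: `"o1"` also matches names like `"o1-mini"`, while `"o3-"` requires the dash. A quick check:

```python
def needs_completion_tokens_rename(model_id: str) -> bool:
    # Tuple form of startswith: True if any prefix matches.
    return model_id.lower().startswith(("o1", "o3-"))

for m in ["o1-preview", "O1", "o3-mini", "gpt-4o", "o3"]:
    print(m, needs_completion_tokens_rename(m))
# "o3" without a dash prints False under this prefix pair.
```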
@@ -777,7 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         if r is not None:
             try:
                 res = await r.json()
-                print(res)
+                log.error(res)
                 if "error" in res:
                     detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except Exception:

@@ -9,6 +9,7 @@ from fastapi import (
     status,
     APIRouter,
 )
+import aiohttp
 import os
 import logging
 import shutil
@@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
     return sorted_filters


-def process_pipeline_inlet_filter(request, payload, user, models):
+async def process_pipeline_inlet_filter(request, payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]

     sorted_filters = get_sorted_filters(model_id, models)
     model = models[model_id]

     if "pipeline" in model:
         sorted_filters.append(model)

-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
+    async with aiohttp.ClientSession() as session:
+        for filter in sorted_filters:
+            urlIdx = filter.get("urlIdx")
+            if urlIdx is None:
+                continue

             url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = request.app.state.config.OPENAI_API_KEYS[urlIdx]

-            if key == "":
+            if not key:
                 continue

             headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                res = r.json()
-                if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
+            request_data = {
+                "user": user,
+                "body": payload,
+            }
+
+            try:
+                async with session.post(
+                    f"{url}/{filter['id']}/filter/inlet",
+                    headers=headers,
+                    json=request_data,
+                ) as response:
+                    response.raise_for_status()
+                    payload = await response.json()
+            except aiohttp.ClientResponseError as e:
+                res = (
+                    await response.json()
+                    if response.content_type == "application/json"
+                    else {}
+                )
+                if "detail" in res:
+                    raise Exception(response.status, res["detail"])
+            except Exception as e:
+                log.exception(f"Connection error: {e}")

     return payload


-def process_pipeline_outlet_filter(request, payload, user, models):
+async def process_pipeline_outlet_filter(request, payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]

     sorted_filters = get_sorted_filters(model_id, models)
     model = models[model_id]

     if "pipeline" in model:
         sorted_filters = [model] + sorted_filters

-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
+    async with aiohttp.ClientSession() as session:
+        for filter in sorted_filters:
+            urlIdx = filter.get("urlIdx")
+            if urlIdx is None:
+                continue

             url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = request.app.state.config.OPENAI_API_KEYS[urlIdx]

-            if key != "":
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers={"Authorization": f"Bearer {key}"},
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-                payload = data
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return Exception(r.status_code, res)
-                except Exception:
-                    pass
-
-            else:
-                pass
+            if not key:
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            request_data = {
+                "user": user,
+                "body": payload,
+            }
+
+            try:
+                async with session.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json=request_data,
+                ) as response:
+                    response.raise_for_status()
+                    payload = await response.json()
+            except aiohttp.ClientResponseError as e:
+                try:
+                    res = (
+                        await response.json()
+                        if "application/json" in response.content_type
+                        else {}
+                    )
+                    if "detail" in res:
+                        raise Exception(response.status, res)
+                except Exception:
+                    pass
+            except Exception as e:
+                log.exception(f"Connection error: {e}")

     return payload

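With both filter passes now coroutines, callers must await them, and all filters share a single `aiohttp.ClientSession` instead of opening a fresh `requests` connection per filter. A minimal sketch of the new call shape and its two-tier error handling; the URL here is a dummy that fails to connect, so the example exercises the `ClientError` branch:

```python
import asyncio
import aiohttp

async def post_filter(session: aiohttp.ClientSession, url: str, body: dict) -> dict:
    # Mirrors the inlet call: raise_for_status() turns 4xx/5xx into
    # aiohttp.ClientResponseError, handled separately from connection failures.
    async with session.post(url, json=body) as response:
        response.raise_for_status()
        return await response.json()

async def main():
    async with aiohttp.ClientSession() as session:  # one session, many filter calls
        try:
            await post_filter(session, "http://localhost:9/filter/inlet", {"body": {}})
        except aiohttp.ClientResponseError as e:
            print("HTTP error:", e.status)
        except aiohttp.ClientError as e:
            print("connection error:", e)

asyncio.run(main())
```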
@@ -161,7 +169,7 @@ router = APIRouter()

 @router.get("/list")
 async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
-    responses = await get_all_models_responses(request)
+    responses = await get_all_models_responses(request, user)
     log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")

     urlIdxs = [
@@ -188,7 +196,7 @@ async def upload_pipeline(
     file: UploadFile = File(...),
     user=Depends(get_admin_user),
 ):
-    print("upload_pipeline", urlIdx, file.filename)
+    log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
     # Check if the uploaded file is a python file
     if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
@@ -223,7 +231,7 @@ async def upload_pipeline(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         status_code = status.HTTP_404_NOT_FOUND
@@ -274,7 +282,7 @@ async def add_pipeline(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         if r is not None:
@@ -319,7 +327,7 @@ async def delete_pipeline(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         if r is not None:
@@ -353,7 +361,7 @@ async def get_pipelines(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         if r is not None:
@@ -392,7 +400,7 @@ async def get_pipeline_valves(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         if r is not None:
@@ -432,7 +440,7 @@ async def get_pipeline_valves_spec(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None
         if r is not None:
@@ -474,7 +482,7 @@ async def update_pipeline_valves(
         return {**data}
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.exception(f"Connection error: {e}")

         detail = None

@@ -147,7 +147,11 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if prompt.user_id != user.id and user.role != "admin":
+    if (
+        prompt.user_id != user.id
+        and not has_access(user.id, "write", prompt.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

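Deletion is now allowed for the owner, an admin, or anyone the prompt's `access_control` grants write access to, rather than owner-or-admin only. A sketch of the predicate with a simplified stand-in for the real `has_access` helper (here `access_control` is modeled as a permission-to-user-ids map, which is an assumption for illustration):

```python
def has_access(user_id: str, permission: str, access_control: dict | None) -> bool:
    # Simplified stand-in for the real access-control helper:
    # access_control maps permission -> list of allowed user ids.
    if access_control is None:
        return False
    return user_id in access_control.get(permission, [])

def may_delete(user_id: str, role: str, owner_id: str, access_control: dict | None) -> bool:
    # De Morgan of the guard above: reject only when all three checks fail.
    return (
        user_id == owner_id
        or has_access(user_id, "write", access_control)
        or role == "admin"
    )

print(may_delete("u2", "user", "u1", {"write": ["u2"]}))  # True via shared write access
```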
@@ -21,6 +21,7 @@ from fastapi import (
     APIRouter,
 )
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
 import tiktoken

@@ -45,17 +46,20 @@ from open_webui.retrieval.web.utils import get_web_loader
 from open_webui.retrieval.web.brave import search_brave
 from open_webui.retrieval.web.kagi import search_kagi
 from open_webui.retrieval.web.mojeek import search_mojeek
+from open_webui.retrieval.web.bocha import search_bocha
 from open_webui.retrieval.web.duckduckgo import search_duckduckgo
 from open_webui.retrieval.web.google_pse import search_google_pse
 from open_webui.retrieval.web.jina_search import search_jina
 from open_webui.retrieval.web.searchapi import search_searchapi
+from open_webui.retrieval.web.serpapi import search_serpapi
 from open_webui.retrieval.web.searxng import search_searxng
 from open_webui.retrieval.web.serper import search_serper
 from open_webui.retrieval.web.serply import search_serply
 from open_webui.retrieval.web.serpstack import search_serpstack
 from open_webui.retrieval.web.tavily import search_tavily
 from open_webui.retrieval.web.bing import search_bing
+from open_webui.retrieval.web.exa import search_exa
+from open_webui.retrieval.web.perplexity import search_perplexity

 from open_webui.retrieval.utils import (
     get_embedding_function,
@@ -347,11 +351,18 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
     return {
         "status": True,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
+        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
+        "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
         "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
+        "enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
         "content_extraction": {
             "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
             "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
             "docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
+            "document_intelligence_config": {
+                "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+                "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+            },
         },
         "chunk": {
             "text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -368,10 +379,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
             "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
         },
         "web": {
-            "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+            "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
             "search": {
                 "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
                 "drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
+                "onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
                 "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
                 "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
                 "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
@@ -379,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
                 "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
                 "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
                 "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+                "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
                 "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
                 "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
                 "serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -386,11 +400,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
                 "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
                 "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
                 "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
+                "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
+                "serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
                 "jina_api_key": request.app.state.config.JINA_API_KEY,
                 "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
                 "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
+                "exa_api_key": request.app.state.config.EXA_API_KEY,
+                "perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
                 "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
                 "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+                "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             },
         },
     }
@@ -401,10 +421,16 @@ class FileConfig(BaseModel):
     max_count: Optional[int] = None


+class DocumentIntelligenceConfigForm(BaseModel):
+    endpoint: str
+    key: str
+
+
 class ContentExtractionConfig(BaseModel):
     engine: str = ""
     tika_server_url: Optional[str] = None
     docling_server_url: Optional[str] = None
+    document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None


 class ChunkParamUpdateForm(BaseModel):
@@ -428,6 +454,7 @@ class WebSearchConfig(BaseModel):
     brave_search_api_key: Optional[str] = None
     kagi_search_api_key: Optional[str] = None
     mojeek_search_api_key: Optional[str] = None
+    bocha_search_api_key: Optional[str] = None
     serpstack_api_key: Optional[str] = None
     serpstack_https: Optional[bool] = None
     serper_api_key: Optional[str] = None
@@ -435,21 +462,31 @@ class WebSearchConfig(BaseModel):
     tavily_api_key: Optional[str] = None
     searchapi_api_key: Optional[str] = None
     searchapi_engine: Optional[str] = None
+    serpapi_api_key: Optional[str] = None
+    serpapi_engine: Optional[str] = None
     jina_api_key: Optional[str] = None
     bing_search_v7_endpoint: Optional[str] = None
     bing_search_v7_subscription_key: Optional[str] = None
+    exa_api_key: Optional[str] = None
+    perplexity_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
+    trust_env: Optional[bool] = None
+    domain_filter_list: Optional[List[str]] = []


 class WebConfig(BaseModel):
     search: WebSearchConfig
-    web_loader_ssl_verification: Optional[bool] = None
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
+    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None


 class ConfigUpdateForm(BaseModel):
+    RAG_FULL_CONTEXT: Optional[bool] = None
+    BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
     pdf_extract_images: Optional[bool] = None
     enable_google_drive_integration: Optional[bool] = None
+    enable_onedrive_integration: Optional[bool] = None
     file: Optional[FileConfig] = None
     content_extraction: Optional[ContentExtractionConfig] = None
     chunk: Optional[ChunkParamUpdateForm] = None

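Every field on these forms is `Optional` with a `None` default, so clients can send only the knobs they care about; `None` then means "leave the stored value alone" in the update handler below. A toy version of the same contract (the form and stored dict are invented for illustration):

```python
from typing import Optional
from pydantic import BaseModel

class ToyConfigForm(BaseModel):
    RAG_FULL_CONTEXT: Optional[bool] = None
    pdf_extract_images: Optional[bool] = None

stored = {"RAG_FULL_CONTEXT": False, "pdf_extract_images": True}
form = ToyConfigForm(RAG_FULL_CONTEXT=True)  # only one field sent

for key, value in form.model_dump().items():
    if value is not None:  # None == "not provided": keep the stored value
        stored[key] = value
print(stored)  # {'RAG_FULL_CONTEXT': True, 'pdf_extract_images': True}
```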
@@ -467,18 +504,38 @@ async def update_rag_config(
         else request.app.state.config.PDF_EXTRACT_IMAGES
     )

+    request.app.state.config.RAG_FULL_CONTEXT = (
+        form_data.RAG_FULL_CONTEXT
+        if form_data.RAG_FULL_CONTEXT is not None
+        else request.app.state.config.RAG_FULL_CONTEXT
+    )
+
+    request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
+        form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
+        if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
+        else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+    )
+
     request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
         form_data.enable_google_drive_integration
         if form_data.enable_google_drive_integration is not None
         else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
     )

+    request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
+        form_data.enable_onedrive_integration
+        if form_data.enable_onedrive_integration is not None
+        else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
+    )
+
     if form_data.file is not None:
         request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
         request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count

     if form_data.content_extraction is not None:
-        log.info(f"Updating text settings: {form_data.content_extraction}")
+        log.info(
+            f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
+        )
         request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
             form_data.content_extraction.engine
         )

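Each new setting repeats the same keep-or-override ternary. A small helper expressing that pattern; this function is not in the codebase and is shown only to make the idiom explicit:

```python
def keep_or_override(new_value, current_value):
    # The repeated ternary from the hunk above, factored out:
    # None means "client did not send this field", so keep the current value.
    return new_value if new_value is not None else current_value

assert keep_or_override(True, False) is True    # override
assert keep_or_override(None, False) is False   # keep current
assert keep_or_override(False, True) is False   # False is a real value, not "unset"
```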
@@ -488,6 +545,13 @@ async def update_rag_config(
         request.app.state.config.DOCLING_SERVER_URL = (
             form_data.content_extraction.docling_server_url
         )
+        if form_data.content_extraction.document_intelligence_config is not None:
+            request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
+                form_data.content_extraction.document_intelligence_config.endpoint
+            )
+            request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
+                form_data.content_extraction.document_intelligence_config.key
+            )

     if form_data.chunk is not None:
         request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
@@ -502,11 +566,16 @@ async def update_rag_config(
     if form_data.web is not None:
         request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
             # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
-            form_data.web.web_loader_ssl_verification
+            form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
         )

         request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
         request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine

+        request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
+            form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+        )
+
         request.app.state.config.SEARXNG_QUERY_URL = (
             form_data.web.search.searxng_query_url
         )
@@ -525,6 +594,9 @@ async def update_rag_config(
         request.app.state.config.MOJEEK_SEARCH_API_KEY = (
             form_data.web.search.mojeek_search_api_key
         )
+        request.app.state.config.BOCHA_SEARCH_API_KEY = (
+            form_data.web.search.bocha_search_api_key
+        )
         request.app.state.config.SERPSTACK_API_KEY = (
             form_data.web.search.serpstack_api_key
         )
@@ -539,6 +611,9 @@ async def update_rag_config(
             form_data.web.search.searchapi_engine
         )

+        request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
+        request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
+
         request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
         request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
             form_data.web.search.bing_search_v7_endpoint
@@ -547,16 +622,30 @@ async def update_rag_config(
             form_data.web.search.bing_search_v7_subscription_key
         )

+        request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
+
+        request.app.state.config.PERPLEXITY_API_KEY = (
+            form_data.web.search.perplexity_api_key
+        )
+
         request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
             form_data.web.search.result_count
         )
         request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
         )
+        request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
+            form_data.web.search.trust_env
+        )
+        request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
+            form_data.web.search.domain_filter_list
+        )

     return {
         "status": True,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
+        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
+        "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
         "file": {
             "max_size": request.app.state.config.FILE_MAX_SIZE,
             "max_count": request.app.state.config.FILE_MAX_COUNT,
@@ -565,6 +654,10 @@ async def update_rag_config(
             "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
             "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
             "docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
+            "document_intelligence_config": {
+                "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+                "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+            },
         },
         "chunk": {
             "text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -577,7 +670,8 @@ async def update_rag_config(
             "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
         },
         "web": {
-            "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+            "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
             "search": {
                 "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
                 "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -587,18 +681,25 @@ async def update_rag_config(
                 "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
                 "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
                 "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+                "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
                 "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
                 "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
                 "serper_api_key": request.app.state.config.SERPER_API_KEY,
                 "serply_api_key": request.app.state.config.SERPLY_API_KEY,
                 "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
                 "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
+                "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
+                "serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
                 "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
                 "jina_api_key": request.app.state.config.JINA_API_KEY,
                 "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
                 "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
+                "exa_api_key": request.app.state.config.EXA_API_KEY,
+                "perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
                 "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                 "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+                "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
+                "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             },
         },
     }
@@ -666,6 +767,7 @@ def save_docs_to_vector_db(
     overwrite: bool = False,
     split: bool = True,
     add: bool = False,
+    user=None,
 ) -> bool:
     def _get_docs_info(docs: list[Document]) -> str:
         docs_info = set()
@@ -746,7 +848,11 @@ def save_docs_to_vector_db(
             # for meta-data so convert them to string.
             for metadata in metadatas:
                 for key, value in metadata.items():
-                    if isinstance(value, datetime):
+                    if (
+                        isinstance(value, datetime)
+                        or isinstance(value, list)
+                        or isinstance(value, dict)
+                    ):
                         metadata[key] = str(value)

         try:

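Vector stores generally only accept scalar metadata, so the widened check now stringifies lists and dicts as well as datetimes before upsert. In isolation (the metadata dict is a made-up example):

```python
from datetime import datetime

metadata = {
    "created": datetime(2025, 1, 1),
    "tags": ["a", "b"],
    "extra": {"k": "v"},
    "page": 3,
}
for key, value in metadata.items():
    if isinstance(value, (datetime, list, dict)):  # same test, tuple form
        metadata[key] = str(value)
print(metadata)  # 'page' stays an int; the rest become strings
```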
@@ -781,7 +887,7 @@ def save_docs_to_vector_db(
             )

             embeddings = embedding_function(
-                list(map(lambda x: x.replace("\n", " "), texts))
+                list(map(lambda x: x.replace("\n", " "), texts)), user=user
             )

             items = [
@@ -829,7 +935,12 @@ def process_file(
                 # Update the content in the file
                 # Usage: /files/{file_id}/data/content/update

-                VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
+                try:
+                    # /files/{file_id}/data/content/update
+                    VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
+                except:
+                    # Audio file upload pipeline
+                    pass

                 docs = [
                     Document(
@@ -887,6 +998,8 @@ def process_file(
                     TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                     DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
                     PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
+                    DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+                    DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
                 )
                 docs = loader.load(
                     file.filename, file.meta.get("content_type"), file_path
@@ -929,6 +1042,7 @@ def process_file(
         hash = calculate_sha256_string(text_content)
         Files.update_file_hash_by_id(file.id, hash)

+        if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
             try:
                 result = save_docs_to_vector_db(
                     request,
@@ -940,6 +1054,7 @@ def process_file(
                         "hash": hash,
                     },
                     add=(True if form_data.collection_name else False),
+                    user=user,
                 )

                 if result:
@@ -958,6 +1073,14 @@ def process_file(
                 }
             except Exception as e:
                 raise e
+        else:
+            return {
+                "status": True,
+                "collection_name": None,
+                "filename": file.filename,
+                "content": text_content,
+            }

     except Exception as e:
         log.exception(e)
         if "No pandoc was found" in str(e):
@@ -997,7 +1120,7 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")

-    result = save_docs_to_vector_db(request, docs, collection_name)
+    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
     if result:
         return {
             "status": True,
@@ -1030,7 +1153,9 @@ def process_youtube_video(
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")

-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(
+            request, docs, collection_name, overwrite=True, user=user
+        )

         return {
             "status": True,
@@ -1071,7 +1196,13 @@ def process_web(
         content = " ".join([doc.page_content for doc in docs])

         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+
+        if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
+            save_docs_to_vector_db(
+                request, docs, collection_name, overwrite=True, user=user
+            )
+        else:
+            collection_name = None

         return {
             "status": True,
@@ -1083,6 +1214,7 @@ def process_web(
             },
             "meta": {
                 "name": form_data.url,
+                "source": form_data.url,
             },
         },
     }
@@ -1102,11 +1234,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
     - BRAVE_SEARCH_API_KEY
     - KAGI_SEARCH_API_KEY
     - MOJEEK_SEARCH_API_KEY
+    - BOCHA_SEARCH_API_KEY
     - SERPSTACK_API_KEY
     - SERPER_API_KEY
     - SERPLY_API_KEY
     - TAVILY_API_KEY
+    - EXA_API_KEY
+    - PERPLEXITY_API_KEY
     - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
+    - SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
     Args:
         query (str): The query to search for
     """
@@ -1168,6 +1304,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
             )
         else:
             raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
+    elif engine == "bocha":
+        if request.app.state.config.BOCHA_SEARCH_API_KEY:
+            return search_bocha(
+                request.app.state.config.BOCHA_SEARCH_API_KEY,
+                query,
+                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+            )
+        else:
+            raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
     elif engine == "serpstack":
         if request.app.state.config.SERPSTACK_API_KEY:
             return search_serpstack(
@@ -1211,6 +1357,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
                 request.app.state.config.TAVILY_API_KEY,
                 query,
                 request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No TAVILY_API_KEY found in environment variables")
@@ -1225,6 +1372,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
             )
         else:
             raise Exception("No SEARCHAPI_API_KEY found in environment variables")
+    elif engine == "serpapi":
+        if request.app.state.config.SERPAPI_API_KEY:
+            return search_serpapi(
+                request.app.state.config.SERPAPI_API_KEY,
+                request.app.state.config.SERPAPI_ENGINE,
+                query,
+                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+            )
+        else:
+            raise Exception("No SERPAPI_API_KEY found in environment variables")
     elif engine == "jina":
         return search_jina(
             request.app.state.config.JINA_API_KEY,
@@ -1240,12 +1398,26 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
             request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
             request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
         )
+    elif engine == "exa":
+        return search_exa(
+            request.app.state.config.EXA_API_KEY,
+            query,
+            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+            request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+        )
+    elif engine == "perplexity":
+        return search_perplexity(
+            request.app.state.config.PERPLEXITY_API_KEY,
+            query,
+            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+            request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+        )
     else:
         raise Exception("No search engine API key found in environment variables")


 @router.post("/process/web/search")
-def process_web_search(
+async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
     try:

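Both new branches follow the chain's fixed argument shape: API key (plus an engine name for searchapi/serpapi), query, result count, and domain filter list. A stub showing that shared contract; the function body and simplified `SearchResult` here are invented, and the real helpers live in `open_webui.retrieval.web.*`:

```python
from dataclasses import dataclass

@dataclass
class SearchResult:  # simplified stand-in for the real SearchResult model
    link: str
    title: str = ""

def search_stub(api_key: str, query: str, count: int, filter_list: list[str]) -> list[SearchResult]:
    # Invented body: every real search_* helper shares this call shape.
    if not api_key:
        raise Exception("No search engine API key found in environment variables")
    return [SearchResult(link=f"https://example.com/?q={query}")][:count]

print(search_stub("key", "open webui", 3, []))
```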
@@ -1277,14 +1449,39 @@ def process_web_search(
             urls,
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
         )
-        docs = loader.load()
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        docs = await loader.aload()
+
+        if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
+            return {
+                "status": True,
+                "collection_name": None,
+                "filenames": urls,
+                "docs": [
+                    {
+                        "content": doc.page_content,
+                        "metadata": doc.metadata,
+                    }
+                    for doc in docs
+                ],
+                "loaded_count": len(docs),
+            }
+        else:
+            await run_in_threadpool(
+                save_docs_to_vector_db,
+                request,
+                docs,
+                collection_name,
+                overwrite=True,
+                user=user,
+            )

         return {
             "status": True,
             "collection_name": collection_name,
             "filenames": urls,
+            "loaded_count": len(docs),
         }
     except Exception as e:
         log.exception(e)

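`process_web_search` is now `async`: pages are fetched through the loader's async path, while the blocking `save_docs_to_vector_db` (embedding plus vector upsert) is pushed onto a worker thread with Starlette's `run_in_threadpool` so it does not stall the event loop. The idea in miniature, with a stand-in for the blocking function:

```python
import asyncio
import time
from fastapi.concurrency import run_in_threadpool

def blocking_index(docs: list[str]) -> int:
    time.sleep(0.1)  # stands in for embedding + vector upsert
    return len(docs)

async def handler() -> int:
    # Offload the sync function; the event loop stays free meanwhile.
    return await run_in_threadpool(blocking_index, ["doc-a", "doc-b"])

print(asyncio.run(handler()))  # 2
```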
@@ -1313,7 +1510,9 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+                    query, user=user
+                ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1321,12 +1520,16 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                user=user,
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(
+                    form_data.query, user=user
+                ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+                user=user,
             )
     except Exception as e:
         log.exception(e)

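The app-level embedding function now takes a `user` keyword (so user info can be forwarded downstream), but the retrieval helpers still expect a one-argument callable, so the handlers bind the current user with a lambda closure. The pattern by itself, with a stand-in embedding function and user object:

```python
def embedding_function(query: str, user=None):
    # Stand-in for request.app.state.EMBEDDING_FUNCTION, which now
    # accepts the requesting user alongside the text.
    return f"embed({query!r}, user={getattr(user, 'id', None)})"

class User:
    id = "u1"

user = User()
# Adapt the two-argument function to the one-argument interface the
# query helpers expect, capturing `user` in the closure:
one_arg_embed = lambda query: embedding_function(query, user=user)
print(one_arg_embed("hello"))
```

`functools.partial(embedding_function, user=user)` would express the same binding; the lambda simply mirrors what the diff does.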
@ -1355,7 +1558,9 @@ def query_collection_handler(
|
||||||
return query_collection_with_hybrid_search(
|
return query_collection_with_hybrid_search(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
query, user=user
|
||||||
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=request.app.state.rf,
|
||||||
r=(
|
r=(
|
||||||
|
|
@@ -1368,7 +1573,9 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+                    query, user=user
+                ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
@@ -1432,11 +1639,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
                 elif os.path.isdir(file_path):
                     shutil.rmtree(file_path)  # Remove the directory
             except Exception as e:
-                print(f"Failed to delete {file_path}. Reason: {e}")
+                log.exception(f"Failed to delete {file_path}. Reason: {e}")
         else:
-            print(f"The directory {folder} does not exist")
+            log.warning(f"The directory {folder} does not exist")
     except Exception as e:
-        print(f"Failed to process the directory {folder}. Reason: {e}")
+        log.exception(f"Failed to process the directory {folder}. Reason: {e}")
     return True
@@ -1516,6 +1723,7 @@ def process_files_batch(
             docs=all_docs,
             collection_name=collection_name,
             add=True,
+            user=user,
         )

         # Update all files with collection name
@@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
 from pydantic import BaseModel
 from typing import Optional
 import logging
+import re

 from open_webui.utils.chat import generate_chat_completion
 from open_webui.utils.task import (
@@ -19,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.constants import TASKS

 from open_webui.routers.pipelines import process_pipeline_inlet_filter
+from open_webui.utils.filter import (
+    get_sorted_filter_ids,
+    process_filter_functions,
+)
 from open_webui.utils.task import get_task_model_id

 from open_webui.config import (
@@ -57,6 +62,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -67,6 +73,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
 class TaskConfigForm(BaseModel):
     TASK_MODEL: Optional[str]
     TASK_MODEL_EXTERNAL: Optional[str]
+    ENABLE_TITLE_GENERATION: bool
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_AUTOCOMPLETE_GENERATION: bool
@@ -85,10 +92,15 @@ async def update_task_config(
 ):
     request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
     request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
+    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
     request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
         form_data.TITLE_GENERATION_PROMPT_TEMPLATE
     )
+
+    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
+        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+    )
+
     request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
         form_data.ENABLE_AUTOCOMPLETE_GENERATION
     )
@@ -117,6 +129,7 @@ async def update_task_config(
     return {
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
+        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
@@ -134,6 +147,18 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
+
+    if not request.app.state.config.ENABLE_TITLE_GENERATION:
+        return JSONResponse(
+            status_code=status.HTTP_200_OK,
+            content={"detail": "Title generation is disabled"},
+        )
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -161,9 +186,20 @@ async def generate_title(
     else:
         template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

+    messages = form_data["messages"]
+
+    # Remove reasoning details from the messages
+    for message in messages:
+        message["content"] = re.sub(
+            r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
+            "",
+            message["content"],
+            flags=re.S,
+        ).strip()
+
     content = title_generation_template(
         template,
-        form_data["messages"],
+        messages,
         {
             "name": user.name,
             "location": user.info.get("location") if user.info else None,
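Note: the title task strips any <details type="reasoning">…</details> blocks (emitted by reasoning models) before templating, so the title is generated from the visible answer only. A standalone sketch of the same regex:

import re

message = (
    '<details type="reasoning" done="true">chain of thought here</details>\n'
    "The capital of France is Paris."
)

# re.S lets "." span newlines, so multi-line reasoning blocks are removed too.
cleaned = re.sub(
    r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
    "",
    message,
    flags=re.S,
).strip()

print(cleaned)  # -> The capital of France is Paris.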
@@ -175,19 +211,26 @@ async def generate_title(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         **(
-            {"max_tokens": 50}
-            if models[task_model_id]["owned_by"] == "ollama"
+            {"max_tokens": 1000}
+            if models[task_model_id].get("owned_by") == "ollama"
             else {
-                "max_completion_tokens": 50,
+                "max_completion_tokens": 1000,
             }
         ),
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
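Note: every task endpoint in this file now builds its metadata the same way: request-scoped metadata (if any) is spread first so the task-specific keys written afterwards take precedence, and the finished payload is routed through the pipeline inlet filters before the completion call. A minimal sketch of that flow, using a hypothetical inlet filter in place of process_pipeline_inlet_filter:

import asyncio
from types import SimpleNamespace

async def inlet_filter(payload: dict) -> dict:
    # Hypothetical stand-in for process_pipeline_inlet_filter; real pipeline
    # filters may rewrite or annotate the payload before it reaches the model.
    payload["messages"].insert(0, {"role": "system", "content": "Be brief."})
    return payload

async def build_and_filter(state: SimpleNamespace) -> dict:
    payload = {
        "messages": [{"role": "user", "content": "Summarize this chat"}],
        "metadata": {
            # Spread request-scoped metadata first; explicit task keys below
            # override any colliding keys.
            **(state.metadata if hasattr(state, "metadata") else {}),
            "task": "title_generation",
        },
    }
    return await inlet_filter(payload)

state = SimpleNamespace(metadata={"session_id": "s-1"})
print(asyncio.run(build_and_filter(state)))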
@@ -209,6 +252,11 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )

+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -245,12 +293,19 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -265,6 +320,11 @@ async def generate_chat_tags(
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -305,12 +365,19 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -340,6 +407,11 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )

+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -376,12 +448,19 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -415,6 +494,11 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )

+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -451,12 +535,19 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -472,6 +563,11 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -509,15 +605,25 @@ async def generate_emoji(
         "stream": False,
         **(
             {"max_tokens": 4}
-            if models[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id].get("owned_by") == "ollama"
             else {
                 "max_completion_tokens": 4,
             }
         ),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -532,7 +638,13 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]

     if model_id not in models:
@@ -565,12 +677,19 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,
         },
     }

+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
     try:
         return await generate_chat_completion(request, form_data=payload, user=user)
     except Exception as e:
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import Optional
@@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.tools import get_tools_specs
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()
@@ -100,7 +105,7 @@ async def create_new_tools(
         specs = get_tools_specs(TOOLS[form_data.id])
         tools = Tools.insert_new_tool(user.id, form_data, specs)

-        tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
+        tool_cache_dir = CACHE_DIR / "tools" / form_data.id
         tool_cache_dir.mkdir(parents=True, exist_ok=True)

         if tools:
@@ -111,7 +116,7 @@ async def create_new_tools(
             detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -193,7 +198,7 @@ async def update_tools_by_id(
         "specs": specs,
     }

-    print(updated)
+    log.debug(updated)
     tools = Tools.update_tool_by_id(id, updated)

     if tools:
@@ -227,7 +232,11 @@ async def delete_tools_by_id(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    if tools.user_id != user.id and user.role != "admin":
+    if (
+        tools.user_id != user.id
+        and not has_access(user.id, "write", tools.access_control)
+        and user.role != "admin"
+    ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
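Note: deletion is now permitted for the owner, an admin, or any user the tool's access-control list grants write access, rather than owner/admin only. A sketch of the same predicate against a simplified access-control shape (the dict layout here is an assumption for illustration):

def has_access(user_id: str, permission: str, access_control) -> bool:
    # Simplified stand-in for open_webui.utils.access_control.has_access:
    # treat access_control as {"write": {"user_ids": [...]}}.
    if access_control is None:
        return False
    return user_id in access_control.get(permission, {}).get("user_ids", [])

def may_delete(user_id: str, role: str, owner_id: str, access_control) -> bool:
    # Mirrors the router check: deny only when the user is not the owner,
    # has no write grant, and is not an admin.
    return not (
        owner_id != user_id
        and not has_access(user_id, "write", access_control)
        and role != "admin"
    )

print(may_delete("u2", "user", "u1", {"write": {"user_ids": ["u2"]}}))  # True
print(may_delete("u3", "user", "u1", None))  # False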
@@ -339,7 +348,7 @@ async def update_tools_valves_by_id(
         Tools.update_tool_valves_by_id(id, valves.model_dump())
         return valves.model_dump()
     except Exception as e:
-        print(e)
+        log.exception(f"Failed to update tool valves by id {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -417,7 +426,7 @@ async def update_tools_user_valves_by_id(
         )
         return user_valves.model_dump()
     except Exception as e:
-        print(e)
+        log.exception(f"Failed to update user valves by id {id}: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
 class FeaturesPermissions(BaseModel):
     web_search: bool = True
     image_generation: bool = True
+    code_interpreter: bool = True


 class UserPermissions(BaseModel):
@@ -152,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
 async def update_user_settings_by_session_user(
     form_data: UserSettings, user=Depends(get_verified_user)
 ):
-    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+    user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
     if user:
         return user.settings
     else:
@@ -1,48 +1,84 @@
 import black
+import logging
 import markdown

 from open_webui.models.chats import ChatTitleMessagesForm
 from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Response, status
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
 from pydantic import BaseModel
 from starlette.responses import FileResponse


 from open_webui.utils.misc import get_gravatar_url
 from open_webui.utils.pdf_generator import PDFGenerator
-from open_webui.utils.auth import get_admin_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.code_interpreter import execute_code_jupyter
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()


 @router.get("/gravatar")
-async def get_gravatar(
-    email: str,
-):
+async def get_gravatar(email: str, user=Depends(get_verified_user)):
     return get_gravatar_url(email)


-class CodeFormatRequest(BaseModel):
+class CodeForm(BaseModel):
     code: str


 @router.post("/code/format")
-async def format_code(request: CodeFormatRequest):
+async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
     try:
-        formatted_code = black.format_str(request.code, mode=black.Mode())
+        formatted_code = black.format_str(form_data.code, mode=black.Mode())
         return {"code": formatted_code}
     except black.NothingChanged:
-        return {"code": request.code}
+        return {"code": form_data.code}
     except Exception as e:
         raise HTTPException(status_code=400, detail=str(e))


+@router.post("/code/execute")
+async def execute_code(
+    request: Request, form_data: CodeForm, user=Depends(get_verified_user)
+):
+    if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
+        output = await execute_code_jupyter(
+            request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+            form_data.code,
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
+                else None
+            ),
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
+                else None
+            ),
+            request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        )
+
+        return output
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Code execution engine not supported",
+        )
+
+
 class MarkdownForm(BaseModel):
     md: str


 @router.post("/markdown")
 async def get_html_from_markdown(
-    form_data: MarkdownForm,
+    form_data: MarkdownForm, user=Depends(get_verified_user)
 ):
     return {"html": markdown.markdown(form_data.md)}
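Note: the new /code/execute route dispatches to a Jupyter backend and forwards either a token or a password depending on the configured auth mode. A sketch of that selection logic in isolation, with plain variables standing in for the request.app.state.config fields:

# Hypothetical config values standing in for request.app.state.config.*
CODE_EXECUTION_JUPYTER_AUTH = "token"          # "token", "password", or None
CODE_EXECUTION_JUPYTER_AUTH_TOKEN = "t0ken"
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = "s3cret"

# Exactly one credential (or neither) is forwarded, depending on the mode.
token = (
    CODE_EXECUTION_JUPYTER_AUTH_TOKEN
    if CODE_EXECUTION_JUPYTER_AUTH == "token"
    else None
)
password = (
    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
    if CODE_EXECUTION_JUPYTER_AUTH == "password"
    else None
)

print(token, password)  # -> t0ken None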
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):

 @router.post("/pdf")
 async def download_chat_as_pdf(
-    form_data: ChatTitleMessagesForm,
+    form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
 ):
     try:
         pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
             headers={"Content-Disposition": "attachment;filename=chat.pdf"},
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating PDF: {e}")
         raise HTTPException(status_code=400, detail=str(e))
@@ -12,6 +12,7 @@ from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
+    WEBSOCKET_REDIS_LOCK_TIMEOUT,
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock
@@ -61,7 +62,7 @@ if WEBSOCKET_MANAGER == "redis":
     clean_up_lock = RedisLock(
         redis_url=WEBSOCKET_REDIS_URL,
         lock_name="usage_cleanup_lock",
-        timeout_secs=TIMEOUT_DURATION * 2,
+        timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock
@@ -279,8 +280,8 @@ def get_event_emitter(request_info):
             await sio.emit(
                 "chat-events",
                 {
-                    "chat_id": request_info["chat_id"],
-                    "message_id": request_info["message_id"],
+                    "chat_id": request_info.get("chat_id", None),
+                    "message_id": request_info.get("message_id", None),
                     "data": event_data,
                 },
                 to=session_id,
@@ -325,19 +326,22 @@ def get_event_emitter(request_info):


 def get_event_call(request_info):
-    async def __event_call__(event_data):
+    async def __event_caller__(event_data):
         response = await sio.call(
             "chat-events",
             {
-                "chat_id": request_info["chat_id"],
-                "message_id": request_info["message_id"],
+                "chat_id": request_info.get("chat_id", None),
+                "message_id": request_info.get("message_id", None),
                 "data": event_data,
             },
             to=request_info["session_id"],
         )
         return response

-    return __event_call__
+    return __event_caller__
+
+
+get_event_caller = get_event_call


 def get_user_id_from_session_pool(sid):
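Note: switching from request_info["chat_id"] to request_info.get("chat_id", None) lets events that carry no chat context (e.g. background tasks) flow through without raising KeyError. A quick sketch of the difference:

request_info = {"session_id": "s-1"}  # no chat_id for a task-only event

# .get degrades to None instead of raising, so the event payload can still
# be emitted for events that are not tied to a chat.
payload = {
    "chat_id": request_info.get("chat_id", None),
    "message_id": request_info.get("message_id", None),
    "data": {"type": "status"},
}
print(payload["chat_id"])  # -> None

try:
    request_info["chat_id"]  # the old pattern
except KeyError:
    print("KeyError with direct indexing")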
BIN  backend/open_webui/static/apple-touch-icon.png  (new file, 7.3 KiB)
BIN  backend/open_webui/static/favicon-96x96.png  (new file, 3.7 KiB)
BIN  backend/open_webui/static/favicon-dark.png  (new file, 16 KiB)
BIN  backend/open_webui/static/favicon.ico  (new file, 15 KiB)
3    backend/open_webui/static/favicon.svg  (new file, 14 KiB)
0    backend/open_webui/static/loader.js  (new file)
21   backend/open_webui/static/site.webmanifest  (new file)
@@ -0,0 +1,21 @@
+{
+    "name": "Open WebUI",
+    "short_name": "WebUI",
+    "icons": [
+        {
+            "src": "/static/web-app-manifest-192x192.png",
+            "sizes": "192x192",
+            "type": "image/png",
+            "purpose": "maskable"
+        },
+        {
+            "src": "/static/web-app-manifest-512x512.png",
+            "sizes": "512x512",
+            "type": "image/png",
+            "purpose": "maskable"
+        }
+    ],
+    "theme_color": "#ffffff",
+    "background_color": "#ffffff",
+    "display": "standalone"
+}
BIN  backend/open_webui/static/splash-dark.png  (new file, 5.3 KiB)
@@ -9308,5 +9308,3 @@
 .json-schema-2020-12__title:first-of-type {
     font-size: 16px;
 }
-
-/*# sourceMappingURL=swagger-ui.css.map*/
BIN  backend/open_webui/static/web-app-manifest-192x192.png  (new file, 8.2 KiB)
BIN  backend/open_webui/static/web-app-manifest-512x512.png  (new file, 29 KiB)
@@ -1,25 +1,41 @@
 import os
 import shutil
 import json
+import logging
 from abc import ABC, abstractmethod
 from typing import BinaryIO, Tuple

 import boto3
+from botocore.config import Config
 from botocore.exceptions import ClientError
 from open_webui.config import (
     S3_ACCESS_KEY_ID,
     S3_BUCKET_NAME,
     S3_ENDPOINT_URL,
+    S3_KEY_PREFIX,
     S3_REGION_NAME,
     S3_SECRET_ACCESS_KEY,
+    S3_USE_ACCELERATE_ENDPOINT,
+    S3_ADDRESSING_STYLE,
     GCS_BUCKET_NAME,
     GOOGLE_APPLICATION_CREDENTIALS_JSON,
+    AZURE_STORAGE_ENDPOINT,
+    AZURE_STORAGE_CONTAINER_NAME,
+    AZURE_STORAGE_KEY,
     STORAGE_PROVIDER,
     UPLOAD_DIR,
 )
 from google.cloud import storage
 from google.cloud.exceptions import GoogleCloudError, NotFound
 from open_webui.constants import ERROR_MESSAGES
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import BlobServiceClient
+from azure.core.exceptions import ResourceNotFoundError
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])


 class StorageProvider(ABC):
@@ -64,7 +80,7 @@ class LocalStorageProvider(StorageProvider):
         if os.path.isfile(file_path):
             os.remove(file_path)
         else:
-            print(f"File {file_path} not found in local storage.")
+            log.warning(f"File {file_path} not found in local storage.")

     @staticmethod
     def delete_all_files() -> None:
@@ -78,30 +94,52 @@ class LocalStorageProvider(StorageProvider):
                 elif os.path.isdir(file_path):
                     shutil.rmtree(file_path)  # Remove the directory
             except Exception as e:
-                print(f"Failed to delete {file_path}. Reason: {e}")
+                log.exception(f"Failed to delete {file_path}. Reason: {e}")
         else:
-            print(f"Directory {UPLOAD_DIR} not found in local storage.")
+            log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")


 class S3StorageProvider(StorageProvider):
     def __init__(self):
-        self.s3_client = boto3.client(
-            "s3",
-            region_name=S3_REGION_NAME,
-            endpoint_url=S3_ENDPOINT_URL,
-            aws_access_key_id=S3_ACCESS_KEY_ID,
-            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-        )
+        config = Config(
+            s3={
+                "use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
+                "addressing_style": S3_ADDRESSING_STYLE,
+            },
+        )
+
+        # If access key and secret are provided, use them for authentication
+        if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY:
+            self.s3_client = boto3.client(
+                "s3",
+                region_name=S3_REGION_NAME,
+                endpoint_url=S3_ENDPOINT_URL,
+                aws_access_key_id=S3_ACCESS_KEY_ID,
+                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+                config=config,
+            )
+        else:
+            # If no explicit credentials are provided, fall back to default AWS credentials
+            # This supports workload identity (IAM roles for EC2, EKS, etc.)
+            self.s3_client = boto3.client(
+                "s3",
+                region_name=S3_REGION_NAME,
+                endpoint_url=S3_ENDPOINT_URL,
+                config=config,
+            )

         self.bucket_name = S3_BUCKET_NAME
+        self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""

     def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
         """Handles uploading of the file to S3 storage."""
         _, file_path = LocalStorageProvider.upload_file(file, filename)
         try:
-            self.s3_client.upload_file(file_path, self.bucket_name, filename)
+            s3_key = os.path.join(self.key_prefix, filename)
+            self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
             return (
                 open(file_path, "rb").read(),
-                "s3://" + self.bucket_name + "/" + filename,
+                "s3://" + self.bucket_name + "/" + s3_key,
             )
         except ClientError as e:
             raise RuntimeError(f"Error uploading file to S3: {e}")
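Note: two behaviors land here: a botocore Config carrying transfer-acceleration and addressing-style options, and a credential fallback so that omitting the key pair defers to boto3's default provider chain (env vars, instance profiles, IAM roles for service accounts). A sketch of the construction under assumed settings:

import boto3
from botocore.config import Config

# Assumed settings; in the app these come from open_webui.config.
S3_USE_ACCELERATE_ENDPOINT = False
S3_ADDRESSING_STYLE = "virtual"  # or "path" for many S3-compatible stores
S3_ACCESS_KEY_ID = None
S3_SECRET_ACCESS_KEY = None

config = Config(
    s3={
        "use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
        "addressing_style": S3_ADDRESSING_STYLE,
    },
)

if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY:
    client = boto3.client(
        "s3",
        aws_access_key_id=S3_ACCESS_KEY_ID,
        aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        config=config,
    )
else:
    # No explicit keys: boto3 walks its default credential chain, which is
    # what makes IAM-role / workload-identity deployments work unmodified.
    client = boto3.client("s3", config=config)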
@@ -109,18 +147,18 @@ class S3StorageProvider(StorageProvider):
     def get_file(self, file_path: str) -> str:
         """Handles downloading of the file from S3 storage."""
         try:
-            bucket_name, key = file_path.split("//")[1].split("/")
-            local_file_path = f"{UPLOAD_DIR}/{key}"
-            self.s3_client.download_file(bucket_name, key, local_file_path)
+            s3_key = self._extract_s3_key(file_path)
+            local_file_path = self._get_local_file_path(s3_key)
+            self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
             return local_file_path
         except ClientError as e:
             raise RuntimeError(f"Error downloading file from S3: {e}")

     def delete_file(self, file_path: str) -> None:
         """Handles deletion of the file from S3 storage."""
-        filename = file_path.split("/")[-1]
         try:
-            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+            s3_key = self._extract_s3_key(file_path)
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
         except ClientError as e:
             raise RuntimeError(f"Error deleting file from S3: {e}")
@@ -133,6 +171,10 @@ class S3StorageProvider(StorageProvider):
             response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
             if "Contents" in response:
                 for content in response["Contents"]:
+                    # Skip objects that were not uploaded from open-webui in the first place
+                    if not content["Key"].startswith(self.key_prefix):
+                        continue
+
                     self.s3_client.delete_object(
                         Bucket=self.bucket_name, Key=content["Key"]
                     )
@@ -142,6 +184,13 @@ class S3StorageProvider(StorageProvider):
         # Always delete from local storage
         LocalStorageProvider.delete_all_files()

+    # The S3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
+    def _extract_s3_key(self, full_file_path: str) -> str:
+        return "/".join(full_file_path.split("//")[1].split("/")[1:])
+
+    def _get_local_file_path(self, s3_key: str) -> str:
+        return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
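Note: _extract_s3_key drops the scheme and bucket from an s3://bucket/prefix/name URI and keeps everything after the bucket, so prefixed keys survive the round trip (the old code unpacked exactly two path segments and broke on any prefix). Tracing it on an example:

def extract_s3_key(full_file_path: str) -> str:
    # "s3://my-bucket/uploads/2024/report.pdf"
    #   .split("//")[1]  -> "my-bucket/uploads/2024/report.pdf"
    #   .split("/")[1:]  -> ["uploads", "2024", "report.pdf"]
    #   "/".join(...)    -> "uploads/2024/report.pdf"
    return "/".join(full_file_path.split("//")[1].split("/")[1:])

print(extract_s3_key("s3://my-bucket/uploads/2024/report.pdf"))
# -> uploads/2024/report.pdf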
@@ -207,6 +256,74 @@ class GCSStorageProvider(StorageProvider):
         LocalStorageProvider.delete_all_files()


+class AzureStorageProvider(StorageProvider):
+    def __init__(self):
+        self.endpoint = AZURE_STORAGE_ENDPOINT
+        self.container_name = AZURE_STORAGE_CONTAINER_NAME
+        storage_key = AZURE_STORAGE_KEY
+
+        if storage_key:
+            # Configure using the Azure Storage Account Endpoint and Key
+            self.blob_service_client = BlobServiceClient(
+                account_url=self.endpoint, credential=storage_key
+            )
+        else:
+            # Configure using the Azure Storage Account Endpoint and DefaultAzureCredential
+            # If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication
+            self.blob_service_client = BlobServiceClient(
+                account_url=self.endpoint, credential=DefaultAzureCredential()
+            )
+        self.container_client = self.blob_service_client.get_container_client(
+            self.container_name
+        )
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to Azure Blob Storage."""
+        contents, file_path = LocalStorageProvider.upload_file(file, filename)
+        try:
+            blob_client = self.container_client.get_blob_client(filename)
+            blob_client.upload_blob(contents, overwrite=True)
+            return contents, f"{self.endpoint}/{self.container_name}/{filename}"
+        except Exception as e:
+            raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")
+
+    def get_file(self, file_path: str) -> str:
+        """Handles downloading of the file from Azure Blob Storage."""
+        try:
+            filename = file_path.split("/")[-1]
+            local_file_path = f"{UPLOAD_DIR}/{filename}"
+            blob_client = self.container_client.get_blob_client(filename)
+            with open(local_file_path, "wb") as download_file:
+                download_file.write(blob_client.download_blob().readall())
+            return local_file_path
+        except ResourceNotFoundError as e:
+            raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")
+
+    def delete_file(self, file_path: str) -> None:
+        """Handles deletion of the file from Azure Blob Storage."""
+        try:
+            filename = file_path.split("/")[-1]
+            blob_client = self.container_client.get_blob_client(filename)
+            blob_client.delete_blob()
+        except ResourceNotFoundError as e:
+            raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")
+
+        # Always delete from local storage
+        LocalStorageProvider.delete_file(file_path)
+
+    def delete_all_files(self) -> None:
+        """Handles deletion of all files from Azure Blob Storage."""
+        try:
+            blobs = self.container_client.list_blobs()
+            for blob in blobs:
+                self.container_client.delete_blob(blob.name)
+        except Exception as e:
+            raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}")
+
+        # Always delete from local storage
+        LocalStorageProvider.delete_all_files()
+
+
 def get_storage_provider(storage_provider: str):
     if storage_provider == "local":
         Storage = LocalStorageProvider()
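Note: the Azure provider mirrors the S3 credential fallback: with AZURE_STORAGE_KEY set it authenticates with the account key, otherwise DefaultAzureCredential walks environment, managed identity, and CLI credentials. A minimal sketch of just the client construction (endpoint and container names are placeholders):

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

endpoint = "https://myaccount.blob.core.windows.net"  # placeholder
storage_key = None  # e.g. read from AZURE_STORAGE_KEY

if storage_key:
    # Shared-key auth: simplest, but the key must be distributed and rotated.
    client = BlobServiceClient(account_url=endpoint, credential=storage_key)
else:
    # DefaultAzureCredential tries env vars, managed identity, Azure CLI, ...
    # so the same code runs locally and on Azure without config changes.
    client = BlobServiceClient(
        account_url=endpoint, credential=DefaultAzureCredential()
    )

container = client.get_container_client("my-container")  # placeholder name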
@@ -214,6 +331,8 @@ def get_storage_provider(storage_provider: str):
         Storage = S3StorageProvider()
     elif storage_provider == "gcs":
         Storage = GCSStorageProvider()
+    elif storage_provider == "azure":
+        Storage = AzureStorageProvider()
     else:
         raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
     return Storage
@@ -7,6 +7,8 @@ from moto import mock_aws
 from open_webui.storage import provider
 from gcp_storage_emulator.server import create_server
 from google.cloud import storage
+from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
+from unittest.mock import MagicMock


 def mock_upload_dir(monkeypatch, tmp_path):
@@ -22,6 +24,7 @@ def test_imports():
     provider.LocalStorageProvider
     provider.S3StorageProvider
     provider.GCSStorageProvider
+    provider.AzureStorageProvider
     provider.Storage
@@ -32,6 +35,8 @@ def test_get_storage_provider():
     assert isinstance(Storage, provider.S3StorageProvider)
     Storage = provider.get_storage_provider("gcs")
     assert isinstance(Storage, provider.GCSStorageProvider)
+    Storage = provider.get_storage_provider("azure")
+    assert isinstance(Storage, provider.AzureStorageProvider)
     with pytest.raises(RuntimeError):
         provider.get_storage_provider("invalid")
@@ -48,6 +53,7 @@ def test_class_instantiation():
     provider.LocalStorageProvider()
     provider.S3StorageProvider()
     provider.GCSStorageProvider()
+    provider.AzureStorageProvider()


 class TestLocalStorageProvider:
@@ -181,6 +187,17 @@ class TestS3StorageProvider:
         assert not (upload_dir / self.filename).exists()
         assert not (upload_dir / self.filename_extra).exists()

+    def test_init_without_credentials(self, monkeypatch):
+        """Test that S3StorageProvider can initialize without explicit credentials."""
+        # Temporarily unset the environment variables
+        monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
+        monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
+
+        # Should not raise an exception
+        storage = provider.S3StorageProvider()
+        assert storage.s3_client is not None
+        assert storage.bucket_name == provider.S3_BUCKET_NAME
+

 class TestGCSStorageProvider:
     Storage = provider.GCSStorageProvider()
@ -272,3 +289,147 @@ class TestGCSStorageProvider:
|
||||||
assert not (upload_dir / self.filename_extra).exists()
|
assert not (upload_dir / self.filename_extra).exists()
|
||||||
assert self.Storage.bucket.get_blob(self.filename) == None
|
assert self.Storage.bucket.get_blob(self.filename) == None
|
||||||
assert self.Storage.bucket.get_blob(self.filename_extra) == None
|
assert self.Storage.bucket.get_blob(self.filename_extra) == None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureStorageProvider:
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def setup_storage(self, monkeypatch):
|
||||||
|
# Create mock Blob Service Client and related clients
|
||||||
|
mock_blob_service_client = MagicMock()
|
||||||
|
mock_container_client = MagicMock()
|
||||||
|
mock_blob_client = MagicMock()
|
||||||
|
|
||||||
|
# Set up return values for the mock
|
||||||
|
        mock_blob_service_client.get_container_client.return_value = (
            mock_container_client
        )
        mock_container_client.get_blob_client.return_value = mock_blob_client

        # Monkeypatch the Azure classes to return our mocks
        monkeypatch.setattr(
            azure.storage.blob,
            "BlobServiceClient",
            lambda *args, **kwargs: mock_blob_service_client,
        )
        monkeypatch.setattr(
            azure.storage.blob,
            "ContainerClient",
            lambda *args, **kwargs: mock_container_client,
        )
        monkeypatch.setattr(
            azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
        )

        self.Storage = provider.AzureStorageProvider()
        self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
        self.Storage.container_name = "my-container"
        self.file_content = b"test content"
        self.filename = "test.txt"
        self.filename_extra = "test_extra.txt"
        self.file_bytesio_empty = io.BytesIO()

        # Apply mocks to the Storage instance
        self.Storage.blob_service_client = mock_blob_service_client
        self.Storage.container_client = mock_container_client

    def test_upload_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)

        # Simulate an error when the container does not exist
        self.Storage.container_client.get_blob_client.side_effect = Exception(
            "Container does not exist"
        )
        with pytest.raises(Exception):
            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)

        # Reset side effect and create container
        self.Storage.container_client.get_blob_client.side_effect = None
        self.Storage.create_container()
        contents, azure_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )

        # Assertions
        self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
        self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
            self.file_content, overwrite=True
        )
        assert contents == self.file_content
        assert (
            azure_file_path
            == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        )
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock upload behavior
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        # Mock blob download behavior
        self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
            self.file_content
        )

        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        file_path = self.Storage.get_file(file_url)

        assert file_path == str(upload_dir / self.filename)
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

    def test_delete_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock file upload
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        # Mock deletion
        self.Storage.container_client.get_blob_client().delete_blob.return_value = None

        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        self.Storage.delete_file(file_url)

        self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
        assert not (upload_dir / self.filename).exists()

    def test_delete_all_files(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock file uploads
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)

        # Mock listing and deletion behavior
        self.Storage.container_client.list_blobs.return_value = [
            {"name": self.filename},
            {"name": self.filename_extra},
        ]
        self.Storage.container_client.get_blob_client().delete_blob.return_value = None

        self.Storage.delete_all_files()

        self.Storage.container_client.list_blobs.assert_called_once()
        self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()

    def test_get_file_not_found(self, monkeypatch):
        self.Storage.create_container()

        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        # Mock behavior to raise an error for missing blobs
        self.Storage.container_client.get_blob_client().download_blob.side_effect = (
            Exception("Blob not found")
        )
        with pytest.raises(Exception, match="Blob not found"):
            self.Storage.get_file(file_url)
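These provider tests run with the rest of the backend suite; a quick way to exercise just the Azure cases from Python might look like this (a sketch, and the test module path is an assumption about the repo layout):

# Sketch: run only the Azure storage provider tests.
# The module path below is assumed, not confirmed by this diff.
import pytest

pytest.main(
    ["backend/open_webui/test/apps/webui/storage/test_provider.py", "-k", "Azure", "-v"]
)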
backend/open_webui/utils/audit.py (new file, 249 lines)
@@ -0,0 +1,249 @@
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
import re
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Dict,
    MutableMapping,
    Optional,
    cast,
)
import uuid

from asgiref.typing import (
    ASGI3Application,
    ASGIReceiveCallable,
    ASGIReceiveEvent,
    ASGISendCallable,
    ASGISendEvent,
    Scope as ASGIScope,
)
from loguru import logger
from starlette.requests import Request

from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
from open_webui.models.users import UserModel


if TYPE_CHECKING:
    from loguru import Logger


@dataclass(frozen=True)
class AuditLogEntry:
    # `Metadata` audit level properties
    id: str
    user: dict[str, Any]
    audit_level: str
    verb: str
    request_uri: str
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None
    # `Request` audit level properties
    request_object: Any = None
    # `Request Response` level properties
    response_object: Any = None
    response_status_code: Optional[int] = None


class AuditLevel(str, Enum):
    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"


class AuditLogger:
    """
    A helper class that encapsulates audit logging functionality. It uses Loguru's logger with an `auditable` binding to ensure that audit log entries are filtered correctly.

    Parameters:
    logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: "Logger"):
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):

        entry = asdict(audit_entry)

        if extra:
            entry["extra"] = extra

        self.logger.log(
            log_level,
            "",
            **entry,
        )


class AuditContext:
    """
    Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture.
    metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, HTTP verb, user agent, etc.).
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.request_body = bytearray()
        self.response_body = bytearray()
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}

    def add_request_chunk(self, chunk: bytes):
        if len(self.request_body) < self.max_body_size:
            self.request_body.extend(
                chunk[: self.max_body_size - len(self.request_body)]
            )

    def add_response_chunk(self, chunk: bytes):
        if len(self.response_body) < self.max_body_size:
            self.response_body.extend(
                chunk[: self.max_body_size - len(self.response_body)]
            )


class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)

                await send(message)

            original_receive = receive

            async def receive_wrapper() -> ASGIReceiveEvent:
                nonlocal original_receive
                message = await original_receive()

                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that ensures an audit log entry is recorded after the request is processed.
        """
        context = AuditContext()
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:

        auth_header = request.headers.get("Authorization")
        assert auth_header
        user = get_current_user(request, None, get_http_authorization_cred(auth_header))

        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        if (
            request.method not in {"POST", "PUT", "PATCH", "DELETE"}
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True
        # match either /api/<resource>/... (for the /api/chat endpoint case) or /api/v1/<resource>/...
        pattern = re.compile(
            r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
        )
        if pattern.match(request.url.path):
            return True

        return False

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")
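A minimal wiring sketch for the middleware above (illustrative only; the excluded paths and audit level shown are assumed example values, not part of this diff):

# Illustrative registration of AuditLoggingMiddleware on a FastAPI app.
# Starlette passes the app as the first constructor argument.
from fastapi import FastAPI

from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware

app = FastAPI()

app.add_middleware(
    AuditLoggingMiddleware,
    excluded_paths=["chats"],  # example: skip /api/chats/... and /api/v1/chats/...
    audit_level=AuditLevel.REQUEST_RESPONSE,
)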
backend/open_webui/utils/auth.py
@@ -1,6 +1,12 @@
 import logging
 import uuid
 import jwt
+import base64
+import hmac
+import hashlib
+import requests
+import os

 from datetime import UTC, datetime, timedelta
 from typing import Optional, Union, List, Dict
@@ -8,14 +14,22 @@ from typing import Optional, Union, List, Dict
 from open_webui.models.users import Users

 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import WEBUI_SECRET_KEY
+from open_webui.env import (
+    WEBUI_SECRET_KEY,
+    TRUSTED_SIGNATURE_KEY,
+    STATIC_DIR,
+    SRC_LOG_LEVELS,
+)

-from fastapi import Depends, HTTPException, Request, Response, status
+from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from passlib.context import CryptContext


 logging.getLogger("passlib").setLevel(logging.ERROR)

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["OAUTH"])

 SESSION_SECRET = WEBUI_SECRET_KEY
 ALGORITHM = "HS256"
@@ -24,6 +38,67 @@ ALGORITHM = "HS256"
 # Auth Utils
 ##############


+def verify_signature(payload: str, signature: str) -> bool:
+    """
+    Verifies the HMAC signature of the received payload.
+    """
+    try:
+        expected_signature = base64.b64encode(
+            hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
+        ).decode()
+
+        # Compare securely to prevent timing attacks
+        return hmac.compare_digest(expected_signature, signature)
+
+    except Exception:
+        return False
+
+
+def override_static(path: str, content: str):
+    # Ensure the path is safe
+    if "/" in path or ".." in path:
+        log.error(f"Invalid path: {path}")
+        return
+
+    file_path = os.path.join(STATIC_DIR, path)
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+    with open(file_path, "wb") as f:
+        f.write(base64.b64decode(content))  # Convert Base64 back to raw binary
+
+
+def get_license_data(app, key):
+    if key:
+        try:
+            res = requests.post(
+                "https://api.openwebui.com/api/v1/license",
+                json={"key": key, "version": "1"},
+                timeout=5,
+            )
+
+            if getattr(res, "ok", False):
+                payload = getattr(res, "json", lambda: {})()
+                for k, v in payload.items():
+                    if k == "resources":
+                        for p, c in v.items():
+                            globals().get("override_static", lambda a, b: None)(p, c)
+                    elif k == "count":
+                        setattr(app.state, "USER_COUNT", v)
+                    elif k == "name":
+                        setattr(app.state, "WEBUI_NAME", v)
+                    elif k == "metadata":
+                        setattr(app.state, "LICENSE_METADATA", v)
+                return True
+            else:
+                log.error(
+                    f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
+                )
+        except Exception as ex:
+            log.exception(f"License: Uncaught Exception: {ex}")
+    return False
+
+
 bearer_security = HTTPBearer(auto_error=False)
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
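A short round trip of the HMAC scheme above (illustrative; the key and payload are made-up values, and it assumes TRUSTED_SIGNATURE_KEY is a bytes key):

# Illustrative round trip for verify_signature; key and payload are examples.
import base64
import hashlib
import hmac

key = b"example-trusted-key"  # stands in for TRUSTED_SIGNATURE_KEY
payload = '{"plan": "team"}'

# The sender signs the payload exactly as verify_signature recomputes it.
signature = base64.b64encode(
    hmac.new(key, payload.encode(), hashlib.sha256).digest()
).decode()

# verify_signature(payload, signature) would now return True; tampering with
# either value makes the constant-time comparison fail and returns False.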
@@ -76,6 +151,7 @@ def get_http_authorization_cred(auth_header: str):

 def get_current_user(
     request: Request,
+    background_tasks: BackgroundTasks,
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
 ):
     token = None
@@ -128,7 +204,10 @@ def get_current_user(
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
         else:
-            Users.update_user_last_active_by_id(user.id)
+            # Refresh the user's last active timestamp asynchronously
+            # to prevent blocking the request
+            if background_tasks:
+                background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
         return user
     else:
         raise HTTPException(
backend/open_webui/utils/chat.py
@@ -7,14 +7,17 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio

-from fastapi import Request
+from fastapi import Request, status
-from starlette.responses import Response, StreamingResponse
+from starlette.responses import Response, StreamingResponse, JSONResponse


 from open_webui.models.users import UserModel

 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )
@@ -44,6 +47,10 @@ from open_webui.utils.response import (
     convert_response_ollama_to_openai,
     convert_streaming_response_ollama_to_openai,
 )
+from open_webui.utils.filter import (
+    get_sorted_filter_ids,
+    process_filter_functions,
+)

 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
@@ -53,29 +60,139 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])


+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    log.info("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        async def message_listener(sid, data):
+            """
+            Handle received socket messages and push them into the queue.
+            """
+            await q.put(data)
+
+        # Register the listener
+        sio.on(channel, message_listener)
+
+        # Start processing the chat completion in the background
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        log.info(f"res: {res}")
+
+        if res.get("status", False):
+            # Define a generator to stream responses
+            async def event_generator():
+                nonlocal q
+                try:
+                    while True:
+                        data = await q.get()  # Wait for new messages
+                        if isinstance(data, dict):
+                            if "done" in data and data["done"]:
+                                break  # Stop streaming when 'done' is received
+
+                            yield f"data: {json.dumps(data)}\n\n"
+                        elif isinstance(data, str):
+                            yield data
+                except Exception as e:
+                    log.debug(f"Error in event generator: {e}")
+                    pass
+
+            # Background task to clean up the listener once streaming finishes
+            async def background():
+                try:
+                    del sio.handlers["/"][channel]
+                except Exception as e:
+                    pass
+
+            # Return the streaming response
+            return StreamingResponse(
+                event_generator(), media_type="text/event-stream", background=background
+            )
+        else:
+            raise Exception(str(res))
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        if "error" in res and res["error"]:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
     user: Any,
     bypass_filter: bool = False,
 ):
+    log.debug(f"generate_chat_completion: {form_data}")
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+        log.debug(f"direct connection to model: {models}")
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
         raise Exception("Model not found")

-    # Process the form_data through the pipeline
-    try:
-        form_data = process_pipeline_inlet_filter(request, form_data, user, models)
-    except Exception as e:
-        raise e
-
     model = models[model_id]

+    if getattr(request.state, "direct", False):
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
+    else:
-    # Check if user has access to the model
-    if not bypass_filter and user.role == "user":
-        try:
+        # Check if user has access to the model
+        if not bypass_filter and user.role == "user":
+            try:
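For orientation, the streaming branch above yields standard server-sent-event frames, one per queued socket message; a consumer sees lines of roughly this shape (the payload keys are illustrative, since the actual dicts come from the socket channel and are not specified by this diff):

# Illustrative SSE frames as emitted by event_generator(); payloads are examples.
frames = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n',
    'data: {"choices": [{"delta": {"content": "lo"}}]}\n\n',
]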
@@ -83,7 +200,7 @@ async def generate_chat_completion(
         except Exception as e:
             raise e

-    if model["owned_by"] == "arena":
+    if model.get("owned_by") == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
         filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
         if model_ids and filter_mode == "exclude":
@@ -136,11 +253,14 @@ async def generate_chat_completion(
         return await generate_function_chat_completion(
             request, form_data, user=user, models=models
         )
-    if model["owned_by"] == "ollama":
+    if model.get("owned_by") == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
         response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+            request=request,
+            form_data=form_data,
+            user=user,
+            bypass_filter=bypass_filter,
         )
         if form_data.get("stream"):
             response.headers["content-type"] = "text/event-stream"
@@ -153,7 +273,10 @@ async def generate_chat_completion(
             return convert_response_ollama_to_openai(response)
     else:
         return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+            request=request,
+            form_data=form_data,
+            user=user,
+            bypass_filter=bypass_filter,
         )
@@ -162,7 +285,13 @@ chat_completion = generate_chat_completion

 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)

+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     data = form_data
@@ -173,121 +302,48 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     model = models[model_id]

     try:
-        data = process_pipeline_outlet_filter(request, data, user, models)
+        data = await process_pipeline_outlet_filter(request, data, user, models)
     except Exception as e:
         return Exception(f"Error: {e}")

-    __event_emitter__ = get_event_emitter(
-        {
-            "chat_id": data["chat_id"],
-            "message_id": data["id"],
-            "session_id": data["session_id"],
-            "user_id": user.id,
-        }
-    )
-
-    __event_call__ = get_event_call(
-        {
-            "chat_id": data["chat_id"],
-            "message_id": data["id"],
-            "session_id": data["session_id"],
-            "user_id": user.id,
-        }
-    )
-
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            # TODO: Fix FunctionModel to include valves
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-    filter_ids = list(set(filter_ids))
-
-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
-
-    # Sort filter_ids by priority, using the get_priority function
-    filter_ids.sort(key=get_priority)
-
-    for filter_id in filter_ids:
-        filter = Functions.get_function_by_id(filter_id)
-        if not filter:
-            continue
-
-        if filter_id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[filter_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(filter_id)
-            request.app.state.FUNCTIONS[filter_id] = function_module
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(filter_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
-
-        if not hasattr(function_module, "outlet"):
-            continue
-        try:
-            outlet = function_module.outlet
-
-            # Get the signature of the function
-            sig = inspect.signature(outlet)
-            params = {"body": data}
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-                "__request__": request,
-            }
-
-            # Add extra params contained in the function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(outlet):
-                data = await outlet(**params)
-            else:
-                data = outlet(**params)
-
-        except Exception as e:
-            return Exception(f"Error: {e}")
-
-    return data
+    metadata = {
+        "chat_id": data["chat_id"],
+        "message_id": data["id"],
+        "session_id": data["session_id"],
+        "user_id": user.id,
+    }
+
+    extra_params = {
+        "__event_emitter__": get_event_emitter(metadata),
+        "__event_call__": get_event_call(metadata),
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+        "__model__": model,
+    }
+
+    try:
+        filter_functions = [
+            Functions.get_function_by_id(filter_id)
+            for filter_id in get_sorted_filter_ids(model)
+        ]
+
+        result, _ = await process_filter_functions(
+            request=request,
+            filter_functions=filter_functions,
+            filter_type="outlet",
+            form_data=data,
+            extra_params=extra_params,
+        )
+        return result
+    except Exception as e:
+        return Exception(f"Error: {e}")


 async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
     if "." in action_id:
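For context, an "outlet" filter executed by the new path is just a module-level function whose extra parameters are injected only when they appear in its signature; a minimal sketch (a hypothetical filter, not part of this diff):

# Hypothetical outlet filter module, as loaded by load_function_module_by_id.
# Only parameters present in the signature (e.g. __user__) are passed in.
async def outlet(body: dict, __user__: dict) -> dict:
    # Example transformation: tag each message with the acting user's role.
    for message in body.get("messages", []):
        message.setdefault("annotations", {})["role"] = __user__.get("role")
    return body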
@@ -300,7 +356,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
         raise Exception(f"Action not found: {action_id}")

     if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)

+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
-    models = request.app.state.MODELS
+        models = request.app.state.MODELS

     data = form_data
@@ -375,7 +437,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
                         )
                     )
                 except Exception as e:
-                    print(e)
+                    log.exception(f"Failed to get user values: {e}")

                 params = {**params, "__user__": __user__}


backend/open_webui/utils/code_interpreter.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import asyncio
import json
import logging
import uuid
from typing import Optional

import aiohttp
import websockets
from pydantic import BaseModel

from open_webui.env import SRC_LOG_LEVELS

logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])


class ResultModel(BaseModel):
    """
    Execute Code Result Model
    """

    stdout: Optional[str] = ""
    stderr: Optional[str] = ""
    result: Optional[str] = ""


class JupyterCodeExecuter:
    """
    Execute code in a Jupyter notebook kernel
    """

    def __init__(
        self,
        base_url: str,
        code: str,
        token: str = "",
        password: str = "",
        timeout: int = 60,
    ):
        """
        :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
        :param code: Code to execute
        :param token: Jupyter authentication token (optional)
        :param password: Jupyter password (optional)
        :param timeout: WebSocket timeout in seconds (default: 60s)
        """
        self.base_url = base_url.rstrip("/")
        self.code = code
        self.token = token
        self.password = password
        self.timeout = timeout
        self.kernel_id = ""
        self.session = aiohttp.ClientSession(base_url=self.base_url)
        self.params = {}
        self.result = ResultModel()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.kernel_id:
            try:
                async with self.session.delete(
                    f"/api/kernels/{self.kernel_id}", params=self.params
                ) as response:
                    response.raise_for_status()
            except Exception as err:
                logger.exception("close kernel failed, %s", err)
        await self.session.close()

    async def run(self) -> ResultModel:
        try:
            await self.sign_in()
            await self.init_kernel()
            await self.execute_code()
        except Exception as err:
            logger.exception("execute code failed, %s", err)
            self.result.stderr = f"Error: {err}"
        return self.result

    async def sign_in(self) -> None:
        # password authentication
        if self.password and not self.token:
            async with self.session.get("/login") as response:
                response.raise_for_status()
                xsrf_token = response.cookies["_xsrf"].value
                if not xsrf_token:
                    raise ValueError("_xsrf token not found")
                self.session.cookie_jar.update_cookies(response.cookies)
                self.session.headers.update({"X-XSRFToken": xsrf_token})
            async with self.session.post(
                "/login",
                data={"_xsrf": xsrf_token, "password": self.password},
                allow_redirects=False,
            ) as response:
                response.raise_for_status()
                self.session.cookie_jar.update_cookies(response.cookies)

        # token authentication
        if self.token:
            self.params.update({"token": self.token})

    async def init_kernel(self) -> None:
        async with self.session.post(
            url="/api/kernels", params=self.params
        ) as response:
            response.raise_for_status()
            kernel_data = await response.json()
            self.kernel_id = kernel_data["id"]

    def init_ws(self) -> (str, dict):
        ws_base = self.base_url.replace("http", "ws")
        ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
        websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
        ws_headers = {}
        if self.password and not self.token:
            ws_headers = {
                "Cookie": "; ".join(
                    [
                        f"{cookie.key}={cookie.value}"
                        for cookie in self.session.cookie_jar
                    ]
                ),
                **self.session.headers,
            }
        return websocket_url, ws_headers

    async def execute_code(self) -> None:
        # initialize the websocket connection
        websocket_url, ws_headers = self.init_ws()
        # execute
        async with websockets.connect(
            websocket_url, additional_headers=ws_headers
        ) as ws:
            await self.execute_in_jupyter(ws)

    async def execute_in_jupyter(self, ws) -> None:
        # send the execute_request message
        msg_id = uuid.uuid4().hex
        await ws.send(
            json.dumps(
                {
                    "header": {
                        "msg_id": msg_id,
                        "msg_type": "execute_request",
                        "username": "user",
                        "session": uuid.uuid4().hex,
                        "date": "",
                        "version": "5.3",
                    },
                    "parent_header": {},
                    "metadata": {},
                    "content": {
                        "code": self.code,
                        "silent": False,
                        "store_history": True,
                        "user_expressions": {},
                        "allow_stdin": False,
                        "stop_on_error": True,
                    },
                    "channel": "shell",
                }
            )
        )
        # parse the reply messages
        stdout, stderr, result = "", "", []
        while True:
            try:
                # wait for a message
                message = await asyncio.wait_for(ws.recv(), self.timeout)
                message_data = json.loads(message)
                # if the msg id does not match, skip it
                if message_data.get("parent_header", {}).get("msg_id") != msg_id:
                    continue
                # check the message type
                msg_type = message_data.get("msg_type")
                match msg_type:
                    case "stream":
                        if message_data["content"]["name"] == "stdout":
                            stdout += message_data["content"]["text"]
                        elif message_data["content"]["name"] == "stderr":
                            stderr += message_data["content"]["text"]
                    case "execute_result" | "display_data":
                        data = message_data["content"]["data"]
                        if "image/png" in data:
                            result.append(f"data:image/png;base64,{data['image/png']}")
                        elif "text/plain" in data:
                            result.append(data["text/plain"])
                    case "error":
                        stderr += "\n".join(message_data["content"]["traceback"])
                    case "status":
                        if message_data["content"]["execution_state"] == "idle":
                            break
            except asyncio.TimeoutError:
                stderr += "\nExecution timed out."
                break
        self.result.stdout = stdout.strip()
        self.result.stderr = stderr.strip()
        self.result.result = "\n".join(result).strip() if result else ""


async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    async with JupyterCodeExecuter(
        base_url, code, token, password, timeout
    ) as executor:
        result = await executor.run()
    return result.model_dump()
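A minimal driver for the helper above (illustrative; the server URL and token are placeholder values, not part of this diff):

# Illustrative usage of execute_code_jupyter; URL and token are example values.
import asyncio

from open_webui.utils.code_interpreter import execute_code_jupyter

result = asyncio.run(
    execute_code_jupyter(
        base_url="http://localhost:8888",
        code="print(21 * 2)",
        token="example-token",
        timeout=30,
    )
)
print(result["stdout"])  # expected: "42"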
backend/open_webui/utils/filter.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import inspect
import logging

from open_webui.utils.plugin import load_function_module_by_id
from open_webui.models.functions import Functions
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def get_sorted_filter_ids(model: dict):
    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            # TODO: Fix FunctionModel to include valves
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    if "info" in model and "meta" in model["info"]:
        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
    filter_ids = list(set(filter_ids))

    enabled_filter_ids = [
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    ]

    filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
    filter_ids.sort(key=get_priority)
    return filter_ids


async def process_filter_functions(
    request, filter_functions, filter_type, form_data, extra_params
):
    skip_files = None

    for function in filter_functions:
        filter = function
        filter_id = function.id
        if not filter:
            continue

        if filter_id in request.app.state.FUNCTIONS:
            function_module = request.app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            request.app.state.FUNCTIONS[filter_id] = function_module

        # Prepare the handler function
        handler = getattr(function_module, filter_type, None)
        if not handler:
            continue

        # Check if the function has a file_handler variable
        if filter_type == "inlet" and hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        try:
            # Prepare parameters
            sig = inspect.signature(handler)

            params = {"body": form_data}
            if filter_type == "stream":
                params = {"event": form_data}

            params = params | {
                k: v
                for k, v in {
                    **extra_params,
                    "__id__": filter_id,
                }.items()
                if k in sig.parameters
            }

            # Handle user parameters
            if "__user__" in sig.parameters:
                if hasattr(function_module, "UserValves"):
                    try:
                        params["__user__"]["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, params["__user__"]["id"]
                            )
                        )
                    except Exception as e:
                        log.exception(f"Failed to get user values: {e}")

            # Execute the handler
            if inspect.iscoroutinefunction(handler):
                form_data = await handler(**params)
            else:
                form_data = handler(**params)

        except Exception as e:
            log.exception(f"Error in {filter_type} handler {filter_id}: {e}")
            raise e

    # Handle file cleanup for inlet filters
    if skip_files and "files" in form_data.get("metadata", {}):
        del form_data["metadata"]["files"]

    return form_data, {}
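Tying the two helpers together, a caller resolves the sorted filter IDs to function records and hands them to the processor; a hedged sketch of an inlet pass (mirroring the outlet call in chat_completed; the request, model, form_data, and user values are assumed to come from the surrounding endpoint):

# Sketch of an inlet pass through the filter pipeline.
from open_webui.models.functions import Functions
from open_webui.utils.filter import get_sorted_filter_ids, process_filter_functions


async def run_inlet_filters(request, model, form_data, user):
    filter_functions = [
        Functions.get_function_by_id(filter_id)
        for filter_id in get_sorted_filter_ids(model)
    ]
    # Extra params are only injected into handlers that declare them.
    form_data, _ = await process_filter_functions(
        request=request,
        filter_functions=filter_functions,
        filter_type="inlet",
        form_data=form_data,
        extra_params={"__user__": {"id": user.id, "role": user.role}},
    )
    return form_data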
@@ -161,7 +161,7 @@ async def comfyui_generate_image(
         seed = (
             payload.seed
             if payload.seed
-            else random.randint(0, 18446744073709551614)
+            else random.randint(0, 1125899906842624)
         )
         for node_id in node.node_ids:
             workflow[node_id]["inputs"][node.key] = seed


backend/open_webui/utils/logger.py (new file, 140 lines)
@@ -0,0 +1,140 @@
import json
import logging
import sys
from typing import TYPE_CHECKING

from loguru import logger

from open_webui.env import (
    AUDIT_LOG_FILE_ROTATION_SIZE,
    AUDIT_LOG_LEVEL,
    AUDIT_LOGS_FILE_PATH,
    GLOBAL_LOG_LEVEL,
)


if TYPE_CHECKING:
    from loguru import Record


def stdout_format(record: "Record") -> str:
    """
    Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).

    Parameters:
    record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.

    Returns:
    str: A formatted log string intended for stdout.
    """
    record["extra"]["extra_json"] = json.dumps(record["extra"])
    return (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        "<level>{message}</level> - {extra[extra_json]}"
        "\n{exception}"
    )


class InterceptHandler(logging.Handler):
    """
    Intercepts log records from Python's standard logging module
    and redirects them to Loguru's logger.
    """

    def emit(self, record):
        """
        Called by the standard logging module for each log event.
        It transforms the standard `LogRecord` into a format compatible with Loguru
        and passes it to Loguru's logger.
        """
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        frame, depth = sys._getframe(6), 6
        while frame and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


def file_format(record: "Record"):
    """
    Formats audit log records into a structured JSON string for file output.

    Parameters:
    record (Record): A Loguru record containing extra audit data.

    Returns:
    str: A JSON-formatted string representing the audit data.
    """

    audit_data = {
        "id": record["extra"].get("id", ""),
        "timestamp": int(record["time"].timestamp()),
        "user": record["extra"].get("user", dict()),
        "audit_level": record["extra"].get("audit_level", ""),
        "verb": record["extra"].get("verb", ""),
        "request_uri": record["extra"].get("request_uri", ""),
        "response_status_code": record["extra"].get("response_status_code", 0),
        "source_ip": record["extra"].get("source_ip", ""),
        "user_agent": record["extra"].get("user_agent", ""),
        "request_object": record["extra"].get("request_object", b""),
        "response_object": record["extra"].get("response_object", b""),
        "extra": record["extra"].get("extra", {}),
    }

    record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
    return "{extra[file_extra]}\n"


def start_logger():
    """
    Initializes and configures Loguru's logger with distinct handlers:

    A console (stdout) handler for general log messages (excluding those marked as auditable).
    An optional file handler for audit logs if audit logging is enabled (AUDIT_LOG_LEVEL is not "NONE").

    Additionally, this function reconfigures Python's standard logging to route through Loguru and adjusts logging levels for Uvicorn.
    """
    logger.remove()

    logger.add(
        sys.stdout,
        level=GLOBAL_LOG_LEVEL,
        format=stdout_format,
        filter=lambda record: "auditable" not in record["extra"],
    )

    if AUDIT_LOG_LEVEL != "NONE":
        try:
            logger.add(
                AUDIT_LOGS_FILE_PATH,
                level="INFO",
                rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
                compression="zip",
                format=file_format,
                filter=lambda record: record["extra"].get("auditable") is True,
            )
        except Exception as e:
            logger.error(f"Failed to initialize audit log file handler: {str(e)}")

    logging.basicConfig(
        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
    )
    for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
        uvicorn_logger.handlers = []
    for uvicorn_logger_name in ["uvicorn.access"]:
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
        uvicorn_logger.handlers = [InterceptHandler()]

    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
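One plausible place to invoke this is application startup, before any request is served; a sketch under that assumption (the lifespan hook shown here is not dictated by this diff):

# Illustrative startup wiring for start_logger; the lifespan hook is assumed.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from open_webui.utils.logger import start_logger


@asynccontextmanager
async def lifespan(app: FastAPI):
    start_logger()  # route stdlib and Uvicorn logging through Loguru
    yield


app = FastAPI(lifespan=lifespan)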
backend/open_webui/utils/misc.py
@@ -2,9 +2,27 @@ import hashlib
 import re
 import time
 import uuid
+import logging
 from datetime import timedelta
 from pathlib import Path
 from typing import Callable, Optional
+import json
+
+
+import collections.abc
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+def deep_update(d, u):
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            d[k] = deep_update(d.get(k, {}), v)
+        else:
+            d[k] = v
+    return d
+
+
 def get_message_list(messages, message_id):
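deep_update merges nested mappings key by key instead of replacing whole subtrees, which is what a plain dict.update would do; for example:

# deep_update merges nested keys instead of overwriting the whole sub-dict.
base = {"ui": {"theme": "dark", "lang": "en"}, "beta": False}
patch = {"ui": {"lang": "de"}}

deep_update(base, patch)
# base is now {"ui": {"theme": "dark", "lang": "de"}, "beta": False};
# base.update(patch) would have dropped the "theme" key entirely.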
@@ -20,7 +38,7 @@ def get_message_list(messages, message_id):
     current_message = messages.get(message_id)

     if not current_message:
-        return f"Message ID {message_id} not found in the history."
+        return None

     # Reconstruct the chain by following the parentId links
     message_list = []
@@ -131,6 +149,44 @@ def add_or_update_system_message(content: str, messages: list[dict]):
     return messages


+def add_or_update_user_message(content: str, messages: list[dict]):
+    """
+    Adds a new user message at the end of the messages list
+    or updates the existing user message at the end.
+
+    :param content: The message content to be added or appended.
+    :param messages: The list of message dictionaries.
+    :return: The updated list of message dictionaries.
+    """
+
+    if messages and messages[-1].get("role") == "user":
+        messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
+    else:
+        # Insert at the end
+        messages.append({"role": "user", "content": content})
+
+    return messages
+
+
+def append_or_update_assistant_message(content: str, messages: list[dict]):
+    """
+    Adds a new assistant message at the end of the messages list
+    or updates the existing assistant message at the end.
+
+    :param content: The message content to be added or appended.
+    :param messages: The list of message dictionaries.
+    :return: The updated list of message dictionaries.
+    """
+
+    if messages and messages[-1].get("role") == "assistant":
+        messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
+    else:
+        # Insert at the end
+        messages.append({"role": "assistant", "content": content})
+
+    return messages
+
+
 def openai_chat_message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
@@ -141,13 +197,24 @@ def openai_chat_message_template(model: str):


 def openai_chat_chunk_message_template(
-    model: str, message: Optional[str] = None, usage: Optional[dict] = None
+    model: str,
+    content: Optional[str] = None,
+    tool_calls: Optional[list[dict]] = None,
+    usage: Optional[dict] = None,
 ) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion.chunk"
-    if message:
-        template["choices"][0]["delta"] = {"content": message}
-    else:
-        template["choices"][0]["delta"] = {}
+
+    template["choices"][0]["index"] = 0
+    template["choices"][0]["delta"] = {}
+
+    if content:
+        template["choices"][0]["delta"]["content"] = content
+
+    if tool_calls:
+        template["choices"][0]["delta"]["tool_calls"] = tool_calls
+
+    if not content and not tool_calls:
         template["choices"][0]["finish_reason"] = "stop"

     if usage:
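The chunk now carries an explicit choice index and can hold tool calls alongside text deltas; the produced dict is roughly the following (an approximate sketch that omits per-call fields such as the generated id):

# Approximate shape of openai_chat_chunk_message_template("llama3", content="Hi").
chunk = {
    "object": "chat.completion.chunk",
    "choices": [{"index": 0, "delta": {"content": "Hi"}}],
}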
@@ -156,12 +223,20 @@ def openai_chat_chunk_message_template(


 def openai_chat_completion_message_template(
-    model: str, message: Optional[str] = None, usage: Optional[dict] = None
+    model: str,
+    message: Optional[str] = None,
+    tool_calls: Optional[list[dict]] = None,
+    usage: Optional[dict] = None,
 ) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion"
     if message is not None:
-        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+        template["choices"][0]["message"] = {
+            "content": message,
+            "role": "assistant",
+            **({"tool_calls": tool_calls} if tool_calls else {}),
+        }

     template["choices"][0]["finish_reason"] = "stop"

     if usage:
@@ -183,10 +258,11 @@ def get_gravatar_url(email):
     return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"


-def calculate_sha256(file):
+def calculate_sha256(file_path, chunk_size):
+    # Compute SHA-256 hash of a file efficiently in chunks
     sha256 = hashlib.sha256()
-    # Read the file in chunks to efficiently handle large files
-    for chunk in iter(lambda: file.read(8192), b""):
-        sha256.update(chunk)
+    with open(file_path, "rb") as f:
+        while chunk := f.read(chunk_size):
+            sha256.update(chunk)
     return sha256.hexdigest()
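
Note: the rewritten helper takes a file path plus an explicit chunk size, so callers now control memory use instead of relying on the previously hard-coded 8192-byte reads. Typical call (path and chunk size illustrative):

    digest = calculate_sha256("/tmp/model.bin", 1024 * 1024)  # hash 1 MiB at a time
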
@@ -342,7 +418,7 @@ def parse_ollama_modelfile(model_text):
                 elif param_type is bool:
                     value = value.lower() == "true"
             except Exception as e:
-                print(e)
+                log.exception(f"Failed to parse parameter {param}: {e}")
                 continue

             data["params"][param] = value
@@ -375,3 +451,15 @@ def parse_ollama_modelfile(model_text):
         data["params"]["messages"] = messages

     return data
+
+
+def convert_logit_bias_input_to_json(user_input):
+    logit_bias_pairs = user_input.split(",")
+    logit_bias_json = {}
+    for pair in logit_bias_pairs:
+        token, bias = pair.split(":")
+        token = str(token.strip())
+        bias = int(bias.strip())
+        bias = 100 if bias > 100 else -100 if bias < -100 else bias
+        logit_bias_json[token] = bias
+    return json.dumps(logit_bias_json)
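
Note: the new helper converts the UI's comma-separated "token:bias" string into the JSON object shape the OpenAI logit_bias parameter expects, clamping biases to the [-100, 100] range. Example (input illustrative):

    convert_logit_bias_input_to_json("5432:90, 623:-150")
    # -> '{"5432": 90, "623": -100}'   (-150 is clamped to -100)
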
@@ -22,6 +22,7 @@ from open_webui.config import (
 )

 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+from open_webui.models.users import UserModel


 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])


-async def get_all_base_models(request: Request):
+async def get_all_base_models(request: Request, user: UserModel = None):
     function_models = []
     openai_models = []
     ollama_models = []

     if request.app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models(request)
+        openai_models = await openai.get_all_models(request, user=user)
         openai_models = openai_models["data"]

     if request.app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models(request)
+        ollama_models = await ollama.get_all_models(request, user=user)
         ollama_models = [
             {
                 "id": model["model"],
@@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
     return models


-async def get_all_models(request):
-    models = await get_all_base_models(request)
+async def get_all_models(request, user: UserModel = None):
+    models = await get_all_base_models(request, user=user)

     # If there are no models, return an empty list
     if len(models) == 0:
@@ -142,7 +143,7 @@ async def get_all_models(request):
                     custom_model.base_model_id == model["id"]
                     or custom_model.base_model_id == model["id"].split(":")[0]
                 ):
-                    owned_by = model["owned_by"]
+                    owned_by = model.get("owned_by", "unknown owner")
                     if "pipe" in model:
                         pipe = model["pipe"]
                     break
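
Note: switching to model.get("owned_by", "unknown owner") means model entries that omit the key (e.g. ones injected by pipes or external providers) no longer raise KeyError during the merge. Minimal illustration:

    model = {"id": "custom-pipe"}  # no "owned_by" key
    # model["owned_by"] would raise KeyError here
    owned_by = model.get("owned_by", "unknown owner")  # -> "unknown owner"
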
@@ -1,6 +1,7 @@
 import base64
 import logging
 import mimetypes
+import sys
 import uuid

 import aiohttp
@@ -35,12 +36,20 @@ from open_webui.config import (
     AppConfig,
 )
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
+from open_webui.env import (
+    WEBUI_NAME,
+    WEBUI_AUTH_COOKIE_SAME_SITE,
+    WEBUI_AUTH_COOKIE_SECURE,
+)
 from open_webui.utils.misc import parse_duration
 from open_webui.utils.auth import get_password_hash, create_token
 from open_webui.utils.webhook import post_webhook

+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["OAUTH"])

 auth_manager_config = AppConfig()
 auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
@@ -61,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN


 class OAuthManager:
-    def __init__(self):
+    def __init__(self, app):
         self.oauth = OAuth()
+        self.app = app
         for _, provider_config in OAUTH_PROVIDERS.items():
             provider_config["register"](self.oauth)
@@ -72,17 +82,21 @@ class OAuthManager:
     def get_user_role(self, user, user_data):
         if user and Users.get_num_users() == 1:
             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
+            log.debug("Assigning the only user the admin role")
             return "admin"
         if not user and Users.get_num_users() == 0:
             # If there are no users, assign the role "admin", as the first user will be an admin
+            log.debug("Assigning the first user the admin role")
             return "admin"

         if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
+            log.debug("Running OAUTH Role management")
             oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
             oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
             oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
             oauth_roles = None
-            role = "pending"  # Default/fallback role if no matching roles are found
+            # Default/fallback role if no matching roles are found
+            role = auth_manager_config.DEFAULT_USER_ROLE

             # Next block extracts the roles from the user data, accepting nested claims of any depth
             if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
@@ -92,17 +106,24 @@ class OAuthManager:
                     claim_data = claim_data.get(nested_claim, {})
                 oauth_roles = claim_data if isinstance(claim_data, list) else None

+            log.debug(f"Oauth Roles claim: {oauth_claim}")
+            log.debug(f"User roles from oauth: {oauth_roles}")
+            log.debug(f"Accepted user roles: {oauth_allowed_roles}")
+            log.debug(f"Accepted admin roles: {oauth_admin_roles}")
+
             # If any roles are found, check if they match the allowed or admin roles
             if oauth_roles:
                 # If role management is enabled, and matching roles are provided, use the roles
                 for allowed_role in oauth_allowed_roles:
                     # If the user has any of the allowed roles, assign the role "user"
                     if allowed_role in oauth_roles:
+                        log.debug("Assigned user the user role")
                         role = "user"
                         break
                 for admin_role in oauth_admin_roles:
                     # If the user has any of the admin roles, assign the role "admin"
                     if admin_role in oauth_roles:
+                        log.debug("Assigned user the admin role")
                         role = "admin"
                         break
         else:
@@ -116,16 +137,34 @@ class OAuthManager:
         return role

     def update_user_groups(self, user, user_data, default_permissions):
+        log.debug("Running OAUTH Group management")
         oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM

-        user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
+        # Nested claim search for groups claim
+        if oauth_claim:
+            claim_data = user_data
+            nested_claims = oauth_claim.split(".")
+            for nested_claim in nested_claims:
+                claim_data = claim_data.get(nested_claim, {})
+            user_oauth_groups = claim_data if isinstance(claim_data, list) else []
+
         user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
         all_available_groups: list[GroupModel] = Groups.get_groups()

+        log.debug(f"Oauth Groups claim: {oauth_claim}")
+        log.debug(f"User oauth groups: {user_oauth_groups}")
+        log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
+        log.debug(
+            f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
+        )
+
         # Remove groups that user is no longer a part of
         for group_model in user_current_groups:
             if group_model.name not in user_oauth_groups:
                 # Remove group from user
+                log.debug(
+                    f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
+                )
+
                 user_ids = group_model.user_ids
                 user_ids = [i for i in user_ids if i != user.id]
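
Note: the groups claim now walks dotted paths exactly like the roles claim above. A sketch of the traversal on a nested ID-token payload (claim path and values illustrative):

    user_data = {"realm_access": {"groups": ["engineering", "admins"]}}
    claim_data = user_data
    for nested_claim in "realm_access.groups".split("."):
        claim_data = claim_data.get(nested_claim, {})
    user_oauth_groups = claim_data if isinstance(claim_data, list) else []
    # -> ["engineering", "admins"]
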
@@ -151,6 +190,9 @@ class OAuthManager:
                 gm.name == group_model.name for gm in user_current_groups
             ):
                 # Add user to group
+                log.debug(
+                    f"Adding user to group {group_model.name} as it was found in their oauth groups"
+                )
                 user_ids = group_model.user_ids
                 user_ids.append(user.id)
@@ -170,7 +212,7 @@ class OAuthManager:
                 id=group_model.id, form_data=update_form, overwrite=False
             )

-    async def handle_login(self, provider, request):
+    async def handle_login(self, request, provider):
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
         # If the provider has a custom redirect URL, use that, otherwise automatically generate one
@@ -182,7 +224,7 @@ class OAuthManager:
             raise HTTPException(404)
         return await client.authorize_redirect(request, redirect_uri)

-    async def handle_callback(self, provider, request, response):
+    async def handle_callback(self, request, provider, response):
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
         client = self.get_client(provider)
@@ -192,7 +234,7 @@ class OAuthManager:
             log.warning(f"OAuth callback error: {e}")
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         user_data: UserInfo = token.get("userinfo")
-        if not user_data:
+        if not user_data or auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data:
             user_data: UserInfo = await client.userinfo(token=token)
         if not user_data:
             log.warning(f"OAuth callback failed, user data is missing: {token}")
@@ -204,11 +246,46 @@ class OAuthManager:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         provider_sub = f"{provider}@{sub}"
         email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
-        email = user_data.get(email_claim, "").lower()
+        email = user_data.get(email_claim, "")
         # We currently mandate that email addresses are provided
         if not email:
-            log.warning(f"OAuth callback failed, email is missing: {user_data}")
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            # If the provider is GitHub, and public email is not provided, we can use the access token to fetch the user's email
+            if provider == "github":
+                try:
+                    access_token = token.get("access_token")
+                    headers = {"Authorization": f"Bearer {access_token}"}
+                    async with aiohttp.ClientSession() as session:
+                        async with session.get(
+                            "https://api.github.com/user/emails", headers=headers
+                        ) as resp:
+                            if resp.ok:
+                                emails = await resp.json()
+                                # use the primary email as the user's email
+                                primary_email = next(
+                                    (e["email"] for e in emails if e.get("primary")),
+                                    None,
+                                )
+                                if primary_email:
+                                    email = primary_email
+                                else:
+                                    log.warning(
+                                        "No primary email found in GitHub response"
+                                    )
+                                    raise HTTPException(
+                                        400, detail=ERROR_MESSAGES.INVALID_CRED
+                                    )
+                            else:
+                                log.warning("Failed to fetch GitHub email")
+                                raise HTTPException(
+                                    400, detail=ERROR_MESSAGES.INVALID_CRED
+                                )
+                except Exception as e:
+                    log.warning(f"Error fetching GitHub email: {e}")
+                    raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            else:
+                log.warning(f"OAuth callback failed, email is missing: {user_data}")
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+        email = email.lower()
         if (
             "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
             and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
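
Note: GitHub's GET https://api.github.com/user/emails returns a list of email objects, and the new code selects the one flagged primary. Sketch of that selection on a sample response (values illustrative; shape per GitHub's REST API):

    emails = [
        {"email": "dev@example.com", "primary": False, "verified": True},
        {"email": "me@example.com", "primary": True, "verified": True},
    ]
    primary_email = next((e["email"] for e in emails if e.get("primary")), None)
    # -> "me@example.com"
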
@@ -236,12 +313,12 @@ class OAuthManager:
             Users.update_user_role_by_id(user.id, determined_role)

         if not user:
+            user_count = Users.get_num_users()
+
             # If the user does not exist, check if signups are enabled
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                 # Check if an existing user with the same email already exists
-                existing_user = Users.get_user_by_email(
-                    user_data.get("email", "").lower()
-                )
+                existing_user = Users.get_user_by_email(email)
                 if existing_user:
                     raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@@ -260,24 +337,35 @@ class OAuthManager:
                         }
                     async with aiohttp.ClientSession() as session:
                         async with session.get(picture_url, **get_kwargs) as resp:
+                            if resp.ok:
                                 picture = await resp.read()
                                 base64_encoded_picture = base64.b64encode(
                                     picture
                                 ).decode("utf-8")
                                 guessed_mime_type = mimetypes.guess_type(
                                     picture_url
                                 )[0]
                                 if guessed_mime_type is None:
                                     # assume JPG, browsers are tolerant enough of image formats
                                     guessed_mime_type = "image/jpeg"
                                 picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
+                            else:
+                                picture_url = "/user.png"
                 except Exception as e:
                     log.error(
                         f"Error downloading profile image '{picture_url}': {e}"
                     )
-                    picture_url = ""
+                    picture_url = "/user.png"
                 if not picture_url:
                     picture_url = "/user.png"

                 username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
+
+                name = user_data.get(username_claim)
+                if not name:
+                    log.warning("Username claim is missing, using email as name")
+                    name = email
+
                 role = self.get_user_role(None, user_data)

                 user = Auths.insert_new_auth(
@@ -285,7 +373,7 @@ class OAuthManager:
                     password=get_password_hash(
                         str(uuid.uuid4())
                     ),  # Random password, not used
-                    name=user_data.get(username_claim, "User"),
+                    name=name,
                     profile_image_url=picture_url,
                     role=role,
                     oauth_sub=provider_sub,
@@ -293,6 +381,7 @@ class OAuthManager:

                 if auth_manager_config.WEBHOOK_URL:
                     post_webhook(
+                        WEBUI_NAME,
                         auth_manager_config.WEBHOOK_URL,
                         WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                         {
@@ -323,8 +412,8 @@ class OAuthManager:
             key="token",
             value=jwt_token,
             httponly=True,  # Ensures the cookie is not accessible via JavaScript
-            samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
-            secure=WEBUI_SESSION_COOKIE_SECURE,
+            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+            secure=WEBUI_AUTH_COOKIE_SECURE,
         )

         if ENABLE_OAUTH_SIGNUP.value:
@@ -333,12 +422,9 @@ class OAuthManager:
                 key="oauth_id_token",
                 value=oauth_id_token,
                 httponly=True,
-                samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
-                secure=WEBUI_SESSION_COOKIE_SECURE,
+                samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+                secure=WEBUI_AUTH_COOKIE_SECURE,
             )
         # Redirect back to the frontend with the JWT token
         redirect_url = f"{request.base_url}auth#token={jwt_token}"
         return RedirectResponse(url=redirect_url, headers=response.headers)
-
-
-oauth_manager = OAuthManager()
@@ -1,17 +1,27 @@
-from open_webui.utils.task import prompt_template
+from open_webui.utils.task import prompt_template, prompt_variables_template
 from open_webui.utils.misc import (
     add_or_update_system_message,
 )

 from typing import Callable, Optional
+import json


 # inplace function: form_data is modified
-def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+def apply_model_system_prompt_to_body(
+    params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
+) -> dict:
     system = params.get("system", None)
     if not system:
         return form_data

+    # Metadata (WebUI Usage)
+    if metadata:
+        variables = metadata.get("variables", {})
+        if variables:
+            system = prompt_variables_template(system, variables)
+
+    # Legacy (API Usage)
     if user:
         template_params = {
             "user_name": user.name,
@@ -19,7 +29,9 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
         }
     else:
         template_params = {}
+
     system = prompt_template(system, **template_params)
+
     form_data["messages"] = add_or_update_system_message(
         system, form_data.get("messages", [])
     )
@@ -50,43 +62,55 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
         "reasoning_effort": str,
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+        "logit_bias": lambda x: x,
     }
     return apply_model_params_to_body(params, form_data, mappings)


 def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
-    opts = [
-        "temperature",
-        "top_p",
-        "seed",
-        "mirostat",
-        "mirostat_eta",
-        "mirostat_tau",
-        "num_ctx",
-        "num_batch",
-        "num_keep",
-        "repeat_last_n",
-        "tfs_z",
-        "top_k",
-        "min_p",
-        "use_mmap",
-        "use_mlock",
-        "num_thread",
-        "num_gpu",
-    ]
-    mappings = {i: lambda x: x for i in opts}
-    form_data = apply_model_params_to_body(params, form_data, mappings)
-
+    # Convert OpenAI parameter names to Ollama parameter names if needed.
     name_differences = {
         "max_tokens": "num_predict",
-        "frequency_penalty": "repeat_penalty",
     }

     for key, value in name_differences.items():
         if (param := params.get(key, None)) is not None:
-            form_data[value] = param
+            # Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
+            params[value] = params[key]
+            del params[key]

-    return form_data
+    # See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8
+    mappings = {
+        "temperature": float,
+        "top_p": float,
+        "seed": lambda x: x,
+        "mirostat": int,
+        "mirostat_eta": float,
+        "mirostat_tau": float,
+        "num_ctx": int,
+        "num_batch": int,
+        "num_keep": int,
+        "num_predict": int,
+        "repeat_last_n": int,
+        "top_k": int,
+        "min_p": float,
+        "typical_p": float,
+        "repeat_penalty": float,
+        "presence_penalty": float,
+        "frequency_penalty": float,
+        "penalize_newline": bool,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+        "numa": bool,
+        "num_gpu": int,
+        "main_gpu": int,
+        "low_vram": bool,
+        "vocab_only": bool,
+        "use_mmap": bool,
+        "use_mlock": bool,
+        "num_thread": int,
+    }
+
+    return apply_model_params_to_body(params, form_data, mappings)


 def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
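
Note: the new Ollama mapping table coerces each option to the type Ollama documents, instead of passing everything through untouched, so string-typed form values arrive correctly typed. A sketch of what apply_model_params_to_body presumably does with such a table (the helper's body is not shown in this diff, so this is an assumption based on its usage):

    params = {"temperature": "0.7", "num_ctx": "4096", "unknown_opt": 1}
    mappings = {"temperature": float, "num_ctx": int}
    form_data = {}
    for key, cast in mappings.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast(value)
    # form_data == {"temperature": 0.7, "num_ctx": 4096}; unmapped keys are dropped
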
@@ -97,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
         new_message = {"role": message["role"]}

         content = message.get("content", [])
+        tool_calls = message.get("tool_calls", None)
+        tool_call_id = message.get("tool_call_id", None)

         # Check if the content is a string (just a simple message)
-        if isinstance(content, str):
+        if isinstance(content, str) and not tool_calls:
             # If the content is a string, it's pure text
             new_message["content"] = content

+            # If message is a tool call, add the tool call id to the message
+            if tool_call_id:
+                new_message["tool_call_id"] = tool_call_id
+
+        elif tool_calls:
+            # If tool calls are present, add them to the message
+            ollama_tool_calls = []
+            for tool_call in tool_calls:
+                ollama_tool_call = {
+                    "index": tool_call.get("index", 0),
+                    "id": tool_call.get("id", None),
+                    "function": {
+                        "name": tool_call.get("function", {}).get("name", ""),
+                        "arguments": json.loads(
+                            tool_call.get("function", {}).get("arguments", {})
+                        ),
+                    },
+                }
+                ollama_tool_calls.append(ollama_tool_call)
+            new_message["tool_calls"] = ollama_tool_calls
+
+            # Put the content to empty string (Ollama requires an empty string for tool calls)
+            new_message["content"] = ""
+
         else:
             # Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
             content_text = ""
@@ -155,37 +206,38 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     )
     ollama_payload["stream"] = openai_payload.get("stream", False)

+    if "tools" in openai_payload:
+        ollama_payload["tools"] = openai_payload["tools"]
+
     if "format" in openai_payload:
         ollama_payload["format"] = openai_payload["format"]

     # If there are advanced parameters in the payload, format them in Ollama's options field
-    ollama_options = {}

     if openai_payload.get("options"):
         ollama_payload["options"] = openai_payload["options"]
         ollama_options = openai_payload["options"]

-    # Handle parameters which map directly
-    for param in ["temperature", "top_p", "seed"]:
-        if param in openai_payload:
-            ollama_options[param] = openai_payload[param]
-
-    # Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
-    if "max_completion_tokens" in openai_payload:
-        ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
-    elif "max_tokens" in openai_payload:
-        ollama_options["num_predict"] = openai_payload["max_tokens"]
-
-    # Handle frequency / presence_penalty, which needs renaming and checking
-    if "frequency_penalty" in openai_payload:
-        ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
-    if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
-        # We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
-        ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
-
-    # Add options to payload if any have been set
-    if ollama_options:
-        ollama_payload["options"] = ollama_options
+        # Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
+        if "max_tokens" in ollama_options:
+            ollama_options["num_predict"] = ollama_options["max_tokens"]
+            del ollama_options[
+                "max_tokens"
+            ]  # To prevent Ollama warning of invalid option provided
+
+        # Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
+        if "system" in ollama_options:
+            ollama_payload["system"] = ollama_options["system"]
+            del ollama_options[
+                "system"
+            ]  # To prevent Ollama warning of invalid option provided
+
+    # If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
+    if "stop" in openai_payload:
+        ollama_options = ollama_payload.get("options", {})
+        ollama_options["stop"] = openai_payload.get("stop")
+        ollama_payload["options"] = ollama_options
+
+    if "metadata" in openai_payload:
+        ollama_payload["metadata"] = openai_payload["metadata"]

     return ollama_payload
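
Note: end to end, the reworked converter carries tools through, hoists "system" out of options, and renames max_tokens to num_predict. A before/after sketch (payload values illustrative):

    openai_payload = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hi"}],
        "stop": ["\n\n"],
        "options": {"max_tokens": 128, "system": "Be terse."},
    }
    # convert_payload_openai_to_ollama(openai_payload) yields roughly:
    # {
    #     "model": "llama3",
    #     "messages": [...],  # converted via convert_messages_openai_to_ollama
    #     "system": "Be terse.",
    #     "options": {"num_predict": 128, "stop": ["\n\n"]},
    #     "stream": False,
    # }
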
@@ -2,6 +2,7 @@ from datetime import datetime
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, Any, List
+from html import escape

 from markdown import markdown

@@ -41,13 +42,13 @@ class PDFGenerator:

     def _build_html_message(self, message: Dict[str, Any]) -> str:
         """Build HTML for a single message."""
-        role = message.get("role", "user")
-        content = message.get("content", "")
+        role = escape(message.get("role", "user"))
+        content = escape(message.get("content", ""))
         timestamp = message.get("timestamp")

-        model = message.get("model") if role == "assistant" else ""
+        model = escape(message.get("model") if role == "assistant" else "")

-        date_str = self.format_timestamp(timestamp) if timestamp else ""
+        date_str = escape(self.format_timestamp(timestamp) if timestamp else "")

         # extends pymdownx extension to convert markdown to html.
         # - https://facelessuser.github.io/pymdown-extensions/usage_notes/
@@ -76,6 +77,7 @@ class PDFGenerator:

     def _generate_html_body(self) -> str:
         """Generate the full HTML body for the PDF."""
+        escaped_title = escape(self.form_data.title)
         return f"""
         <html>
         <head>
@@ -84,7 +86,7 @@ class PDFGenerator:
         <body>
         <div>
             <div>
-                <h2>{self.form_data.title}</h2>
+                <h2>{escaped_title}</h2>
                 {self.messages_html}
             </div>
         </div>
@@ -108,7 +110,7 @@ class PDFGenerator:
         # When running using `pip install -e .` the static directory is in the site packages.
         # This path only works if `open-webui serve` is run from the root of this project.
         if not FONTS_DIR.exists():
-            FONTS_DIR = Path("./backend/static/fonts")
+            FONTS_DIR = Path(".") / "backend" / "static" / "fonts"

         pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
         pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
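
Note: escaping the user-controlled fields before they are interpolated into the PDF's HTML prevents chat content from injecting markup. Minimal illustration:

    from html import escape

    escape('<script>alert("hi")</script>')
    # -> '&lt;script&gt;alert(&quot;hi&quot;)&lt;/script&gt;'
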
@@ -45,7 +45,7 @@ def extract_frontmatter(content):
                 frontmatter[key.strip()] = value.strip()

     except Exception as e:
-        print(f"An error occurred: {e}")
+        log.exception(f"Failed to extract frontmatter: {e}")
         return {}

     return frontmatter
@@ -167,9 +167,14 @@ def load_function_module_by_id(function_id, content=None):


 def install_frontmatter_requirements(requirements):
     if requirements:
-        req_list = [req.strip() for req in requirements.split(",")]
-        for req in req_list:
-            log.info(f"Installing requirement: {req}")
-            subprocess.check_call([sys.executable, "-m", "pip", "install", req])
+        try:
+            req_list = [req.strip() for req in requirements.split(",")]
+            for req in req_list:
+                log.info(f"Installing requirement: {req}")
+                subprocess.check_call([sys.executable, "-m", "pip", "install", req])
+        except Exception as e:
+            log.error(f"Error installing package: {req}")
+            raise e
     else:
         log.info("No requirements found in frontmatter.")
@@ -1,29 +1,31 @@
 import json
+from uuid import uuid4
 from open_webui.utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
 )


-def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
-    model = ollama_response.get("model", "ollama")
-    message_content = ollama_response.get("message", {}).get("content", "")
-
-    response = openai_chat_completion_message_template(model, message_content)
-    return response
+def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
+    openai_tool_calls = []
+    for tool_call in tool_calls:
+        openai_tool_call = {
+            "index": tool_call.get("index", 0),
+            "id": tool_call.get("id", f"call_{str(uuid4())}"),
+            "type": "function",
+            "function": {
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": json.dumps(
+                    tool_call.get("function", {}).get("arguments", {})
+                ),
+            },
+        }
+        openai_tool_calls.append(openai_tool_call)
+    return openai_tool_calls


-async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
-    async for data in ollama_streaming_response.body_iterator:
-        data = json.loads(data)
-
-        model = data.get("model", "ollama")
-        message_content = data.get("message", {}).get("content", "")
-        done = data.get("done", False)
-
-        usage = None
-        if done:
-            usage = {
+def convert_ollama_usage_to_openai(data: dict) -> dict:
+    return {
         "response_token/s": (
             round(
                 (
@@ -55,16 +57,68 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
         "total_duration": data.get("total_duration", 0),
         "load_duration": data.get("load_duration", 0),
         "prompt_eval_count": data.get("prompt_eval_count", 0),
+        "prompt_tokens": int(
+            data.get("prompt_eval_count", 0)
+        ),  # This is the OpenAI compatible key
         "prompt_eval_duration": data.get("prompt_eval_duration", 0),
         "eval_count": data.get("eval_count", 0),
+        "completion_tokens": int(
+            data.get("eval_count", 0)
+        ),  # This is the OpenAI compatible key
         "eval_duration": data.get("eval_duration", 0),
-        "approximate_total": (
-            lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
-        )((data.get("total_duration", 0) or 0) // 1_000_000_000),
+        "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
+            (data.get("total_duration", 0) or 0) // 1_000_000_000
+        ),
+        "total_tokens": int(  # This is the OpenAI compatible key
+            data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
+        ),
+        "completion_tokens_details": {  # This is the OpenAI compatible key
+            "reasoning_tokens": 0,
+            "accepted_prediction_tokens": 0,
+            "rejected_prediction_tokens": 0,
+        },
     }


+def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
+    model = ollama_response.get("model", "ollama")
+    message_content = ollama_response.get("message", {}).get("content", "")
+    tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
+    openai_tool_calls = None
+
+    if tool_calls:
+        openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
+
+    data = ollama_response
+    usage = convert_ollama_usage_to_openai(data)
+
+    response = openai_chat_completion_message_template(
+        model, message_content, openai_tool_calls, usage
+    )
+    return response
+
+
+async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
+    async for data in ollama_streaming_response.body_iterator:
+        data = json.loads(data)
+
+        model = data.get("model", "ollama")
+        message_content = data.get("message", {}).get("content", None)
+        tool_calls = data.get("message", {}).get("tool_calls", None)
+        openai_tool_calls = None
+
+        if tool_calls:
+            openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
+
+        done = data.get("done", False)
+
+        usage = None
+        if done:
+            usage = convert_ollama_usage_to_openai(data)
+
         data = openai_chat_chunk_message_template(
-            model, message_content if not done else None, usage
+            model, message_content, openai_tool_calls, usage
         )

         line = f"data: {json.dumps(data)}\n\n"
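
Note: Ollama reports tool-call arguments as a dict, while OpenAI-style clients expect a JSON string plus an id and type on each call; convert_ollama_tool_call_to_openai bridges that. Sketch of one converted entry (values illustrative):

    ollama_tool_call = {
        "function": {"name": "get_weather", "arguments": {"city": "Oslo"}}
    }
    # convert_ollama_tool_call_to_openai([ollama_tool_call]) -> roughly:
    # [{
    #     "index": 0,
    #     "id": "call_<uuid4>",  # generated because Ollama supplied no id
    #     "type": "function",
    #     "function": {"name": "get_weather", "arguments": '{"city": "Oslo"}'},
    # }]
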
@@ -22,7 +22,7 @@ def get_task_model_id(
     # Set the task model
     task_model_id = default_model_id
     # Check if the user has a custom task model and use that model
-    if models[task_model_id]["owned_by"] == "ollama":
+    if models[task_model_id].get("owned_by") == "ollama":
         if task_model and task_model in models:
             task_model_id = task_model
         else:
@@ -32,6 +32,12 @@ def get_task_model_id(
     return task_model_id


+def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
+    for variable, value in variables.items():
+        template = template.replace(variable, value)
+    return template
+
+
 def prompt_template(
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str:
@@ -98,7 +104,7 @@ def replace_prompt_variable(template: str, prompt: str) -> str:


 def replace_messages_variable(
-    template: str, messages: Optional[list[str]] = None
+    template: str, messages: Optional[list[dict]] = None
 ) -> str:
     def replacement_function(match):
         full_match = match.group(0)
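
Note: prompt_variables_template does plain substring replacement, so variable keys are expected to carry their own delimiters. Example (delimiter syntax illustrative):

    system = "Reply in {{LANGUAGE}} for {{USER_NAME}}."
    prompt_variables_template(
        system, {"{{LANGUAGE}}": "French", "{{USER_NAME}}": "Ada"}
    )
    # -> "Reply in French for Ada."
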
@@ -61,6 +61,12 @@ def get_tools(
         )

         for spec in tools.specs:
+            # TODO: Fix hack for OpenAI API
+            # Some times breaks OpenAI but others don't. Leaving the comment
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+
             # Remove internal parameters
             spec["parameters"]["properties"] = {
                 key: val
@@ -73,6 +79,13 @@ def get_tools(
             # convert to function that takes only model params and inserts custom params
             original_func = getattr(module, function_name)
             callable = apply_extra_params_to_tool_function(original_func, extra_params)
+
+            if callable.__doc__ and callable.__doc__.strip() != "":
+                s = re.split(":(param|return)", callable.__doc__, 1)
+                spec["description"] = s[0]
+            else:
+                spec["description"] = function_name
+
             # TODO: This needs to be a pydantic model
             tool_dict = {
                 "toolkit_id": tool_id,
@@ -2,14 +2,14 @@ import json
 import logging

 import requests
-from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
+from open_webui.config import WEBUI_FAVICON_URL
 from open_webui.env import SRC_LOG_LEVELS, VERSION

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])


-def post_webhook(url: str, message: str, event_data: dict) -> bool:
+def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
     try:
         log.debug(f"post_webhook: {url}, {message}, {event_data}")
         payload = {}
@@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
             "sections": [
                 {
                     "activityTitle": message,
-                    "activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
+                    "activitySubtitle": f"{name} ({VERSION}) - {action}",
                     "activityImage": WEBUI_FAVICON_URL,
                     "facts": facts,
                     "markdown": True,
@@ -1,29 +1,26 @@
-fastapi==0.111.0
-uvicorn[standard]==0.30.6
-pydantic==2.9.2
+fastapi==0.115.7
+uvicorn[standard]==0.34.0
+pydantic==2.10.6
 python-multipart==0.0.18

-Flask==3.1.0
-Flask-Cors==5.0.0
-
 python-socketio==5.11.3
-python-jose==3.3.0
+python-jose==3.4.0
 passlib[bcrypt]==1.7.4

 requests==2.32.3
-aiohttp==3.11.8
+aiohttp==3.11.11
 async-timeout
 aiocache
 aiofiles

-sqlalchemy==2.0.32
+sqlalchemy==2.0.38
 alembic==1.14.0
-peewee==3.17.8
+peewee==3.17.9
 peewee-migrate==1.12.2
 psycopg2-binary==2.9.9
 pgvector==0.3.5
 PyMySQL==1.1.1
-bcrypt==4.2.0
+bcrypt==4.3.0

 pymongo
 redis
@@ -32,20 +29,27 @@ boto3==1.35.53
 argon2-cffi==23.1.0
 APScheduler==3.10.4

+RestrictedPython==8.0
+
+loguru==0.7.2
+asgiref==3.8.1
+
 # AI libraries
 openai
 anthropic
 google-generativeai==0.7.2
 tiktoken

-langchain==0.3.7
-langchain-community==0.3.7
+langchain==0.3.19
+langchain-community==0.3.18

 fake-useragent==1.5.1
 chromadb==0.6.2
 pymilvus==2.5.0
 qdrant-client~=1.12.0
-opensearch-py==2.7.1
+opensearch-py==2.8.0
+playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
+elasticsearch==8.17.1

 transformers
@@ -57,10 +61,10 @@ einops==0.8.0
 ftfy==6.2.3
 pypdf==4.3.1
 fpdf2==2.8.2
-pymdown-extensions==10.11.2
+pymdown-extensions==10.14.2
 docx2txt==0.8
 python-pptx==1.0.0
-unstructured==0.15.9
+unstructured==0.16.17
 nltk==3.9.1
 Markdown==3.7
 pypandoc==1.13
@@ -71,25 +75,26 @@ xlrd==2.0.1
 validators==0.34.0
 psutil
 sentencepiece
-soundfile==0.12.1
+soundfile==0.13.1
+azure-ai-documentintelligence==1.0.0

-opencv-python-headless==4.10.0.84
+opencv-python-headless==4.11.0.86
 rapidocr-onnxruntime==1.3.24
 rank-bm25==0.2.2

-faster-whisper==1.0.3
+faster-whisper==1.1.1

 PyJWT[crypto]==2.10.1
-authlib==1.3.2
+authlib==1.4.1

-black==24.8.0
+black==25.1.0
 langfuse==2.44.0
 youtube-transcript-api==0.6.3
 pytube==15.0.0

 extract_msg
 pydub
-duckduckgo-search~=7.2.1
+duckduckgo-search~=7.3.2

 ## Google Drive
 google-api-python-client
@@ -104,5 +109,12 @@ pytest-docker~=3.1.1
 googleapis-common-protos==1.63.2
 google-cloud-storage==2.19.0

+azure-identity==1.20.0
+azure-storage-blob==12.24.1
+
+
 ## LDAP
 ldap3==2.9.1
+
+## Firecrawl
+firecrawl-py==1.12.0
@@ -3,6 +3,17 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd "$SCRIPT_DIR" || exit

+# Add conditional Playwright browser installation
+if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
+    if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then
+        echo "Installing Playwright browsers..."
+        playwright install chromium
+        playwright install-deps chromium
+    fi
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+fi
+
 KEY_FILE=.webui_secret_key

 PORT="${PORT:-8080}"
@@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION
 SET "SCRIPT_DIR=%~dp0"
 cd /d "%SCRIPT_DIR%" || exit /b

+:: Add conditional Playwright browser installation
+IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" (
+    IF "%PLAYWRIGHT_WS_URI%" == "" (
+        echo Installing Playwright browsers...
+        playwright install chromium
+        playwright install-deps chromium
+    )
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+)
+
 SET "KEY_FILE=.webui_secret_key"
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0
10
docker-compose.playwright.yaml
Normal file
@@ -0,0 +1,10 @@
+services:
+  playwright:
+    image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt
+    container_name: playwright
+    command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0
+
+  open-webui:
+    environment:
+      - 'RAG_WEB_LOADER_ENGINE=playwright'
+      - 'PLAYWRIGHT_WS_URI=ws://playwright:3000'
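
Note: the new compose file is an overlay rather than a standalone stack; presumably it is meant to be layered over the main compose file, e.g. docker compose -f docker-compose.yaml -f docker-compose.playwright.yaml up, so the open-webui service picks up the two Playwright environment variables and connects to the browser container over PLAYWRIGHT_WS_URI.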