diff --git a/CHANGELOG.md b/CHANGELOG.md
index bad83dc1ef..126f14e006 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.5.11] - 2025-02-13
+
+### Added
+
+- **🎤 Kokoro-JS TTS Support**: A new on-device, high-quality text-to-speech engine has been integrated, vastly improving voice generation quality—everything runs directly in your browser.
+- **🐍 Jupyter Notebook Support in Code Interpreter**: Now, you can configure Code Interpreter to run Python code not only via Pyodide but also through Jupyter, offering a more robust coding environment for AI-driven computations and analysis.
+- **🔗 Direct API Connections for Private & Local Inference**: You can now connect Open WebUI to your private or localhost API inference endpoints. CORS must be enabled on the endpoint, but this unlocks direct, on-device AI infrastructure support.
+- **🔍 Advanced Domain Filtering for Web Search**: You can now specify which domains should be included or excluded from web searches, refining results for more relevant information retrieval.
+- **🚀 Improved Image Generation Metadata Handling**: Generated images now retain metadata for better organization and future retrieval.
+- **📂 S3 Key Prefix Support**: Fine-grained control over S3 storage file structuring with configurable key prefixes.
+- **📸 Support for Image-Only Messages**: Send messages containing only images, facilitating more visual-centric interactions.
+- **🌍 Updated Translations**: German, Spanish, Traditional Chinese, and Catalan translations updated for better multilingual support.
+
+### Fixed
+
+- **🔧 OAuth Debug Logs & Username Claim Fixes**: Debug logs have been added for OAuth role and group management, with fixes ensuring proper OAuth username retrieval and claim handling.
+- **📌 Citations Formatting & Toggle Fixes**: Inline citation toggles now function correctly, and citations with more than three sources are fully visible when expanded.
+- **📸 ComfyUI Maximum Seed Value Constraint Fixed**: The maximum allowed seed value for ComfyUI has been corrected, preventing unintended behavior.
+- **🔑 Connection Settings Stability**: Addressed connection settings issues that were causing instability when saving configurations.
+- **📂 GGUF Model Upload Stability**: Fixed upload inconsistencies for GGUF models, ensuring reliable local model handling.
+- **🔧 Web Search Configuration Bug**: Fixed issues where web search filters and settings weren't correctly applied.
+- **💾 User Settings Persistence Fix**: Ensured user-specific settings are correctly saved and applied across sessions.
+- **🔄 OpenID Username Retrieval Enhancement**: Usernames are now correctly picked up and assigned for OpenID Connect (OIDC) logins.
+
+### Changed
+
+- **🔗 Improved Direct Connections Integration**: Simplified the configuration process for setting up direct API connections, making it easier to integrate custom inference endpoints.
+
## [0.5.10] - 2025-02-05
### Fixed
diff --git a/README.md b/README.md
index 0fb03537df..56ab09b05d 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa
In the last part of the command, replace `open-webui` with your container name if it is different.
-Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
+Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
### Using the Dev Branch 🌙
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index bf6f1d0256..ff298dc5b9 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -660,6 +660,7 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
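+# Optional key prefix (e.g. "open-webui/") that the S3 storage provider joins onto uploaded object keys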
+S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
@@ -682,6 +683,17 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+
+####################################
+# DIRECT CONNECTIONS
+####################################
+
+ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
+ "ENABLE_DIRECT_CONNECTIONS",
+ "direct.enable",
+ os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
+)
+
####################################
# OLLAMA_BASE_URL
####################################
@@ -1325,6 +1337,54 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""
+####################################
+# Code Interpreter
+####################################
+
+ENABLE_CODE_INTERPRETER = PersistentConfig(
+ "ENABLE_CODE_INTERPRETER",
+ "code_interpreter.enable",
+ os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
+)
+
+CODE_INTERPRETER_ENGINE = PersistentConfig(
+ "CODE_INTERPRETER_ENGINE",
+ "code_interpreter.engine",
+ os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
+)
+
+CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
+ "CODE_INTERPRETER_PROMPT_TEMPLATE",
+ "code_interpreter.prompt_template",
+ os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_URL",
+ "code_interpreter.jupyter.url",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH",
+ "code_interpreter.jupyter.auth",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
+ "code_interpreter.jupyter.auth_token",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
+)
+
+
+CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
+ "code_interpreter.jupyter.auth_password",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
+)
+
+
DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
@@ -1335,9 +1395,8 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- - If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
+  - **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS reproduce it word for word and explicitly display it as part of the response so the user can access it easily, do NOT change the link.**
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- - **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
@@ -1645,7 +1704,7 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
- "rag.rag.web.search.domain.filter_list",
+ "rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
@@ -1690,6 +1749,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
+BOCHA_SEARCH_API_KEY = PersistentConfig(
+ "BOCHA_SEARCH_API_KEY",
+ "rag.web.search.bocha_search_api_key",
+ os.getenv("BOCHA_SEARCH_API_KEY", ""),
+)
+
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
@@ -2012,6 +2077,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
+# Add Deepgram configuration
+DEEPGRAM_API_KEY = PersistentConfig(
+ "DEEPGRAM_API_KEY",
+ "audio.stt.deepgram.api_key",
+ os.getenv("DEEPGRAM_API_KEY", ""),
+)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 00605e15dc..0be3887f82 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -92,6 +92,7 @@ log_sources = [
"RAG",
"WEBHOOK",
"SOCKET",
+ "OAUTH",
]
SRC_LOG_LEVELS = {}
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 863f58dea5..88b5b3f692 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -97,6 +97,16 @@ from open_webui.config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
+ # Direct Connections
+ ENABLE_DIRECT_CONNECTIONS,
+ # Code Interpreter
+ ENABLE_CODE_INTERPRETER,
+ CODE_INTERPRETER_ENGINE,
+ CODE_INTERPRETER_PROMPT_TEMPLATE,
+ CODE_INTERPRETER_JUPYTER_URL,
+ CODE_INTERPRETER_JUPYTER_AUTH,
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@@ -130,6 +140,7 @@ from open_webui.config import (
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
WHISPER_MODEL,
+ DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
# Retrieval
@@ -180,6 +191,7 @@ from open_webui.config import (
EXA_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
+ BOCHA_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
@@ -322,7 +334,11 @@ class SPAStaticFiles(StaticFiles):
return await super().get_response(path, scope)
except (HTTPException, StarletteHTTPException) as ex:
if ex.status_code == 404:
- return await super().get_response("index.html", scope)
+ if path.endswith(".js"):
+                    # Return 404 for missing JavaScript files instead of the SPA fallback, so stale bundle requests fail fast rather than receiving index.html
+ raise ex
+ else:
+ return await super().get_response("index.html", scope)
else:
raise ex
@@ -389,6 +405,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
app.state.OPENAI_MODELS = {}
+########################################
+#
+# DIRECT CONNECTIONS
+#
+########################################
+
+app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+
########################################
#
# WEBUI
@@ -514,6 +538,7 @@ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
+app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -569,6 +594,24 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
+########################################
+#
+# CODE INTERPRETER
+#
+########################################
+
+app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
+app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
+app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
+
+app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+)
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+)
########################################
#
@@ -611,6 +654,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
+app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@@ -753,6 +797,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
+
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -855,20 +900,30 @@ async def chat_completion(
if not request.app.state.MODELS:
await get_all_models(request)
+ model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
- try:
- model_id = form_data.get("model", None)
- if model_id not in request.app.state.MODELS:
- raise Exception("Model not found")
- model = request.app.state.MODELS[model_id]
- model_info = Models.get_model_by_id(model_id)
- # Check if user has access to the model
- if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
- try:
- check_model_access(user, model)
- except Exception as e:
- raise e
+ try:
+ if not model_item.get("direct", False):
+ model_id = form_data.get("model", None)
+ if model_id not in request.app.state.MODELS:
+ raise Exception("Model not found")
+
+ model = request.app.state.MODELS[model_id]
+ model_info = Models.get_model_by_id(model_id)
+
+ # Check if user has access to the model
+ if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
+ try:
+ check_model_access(user, model)
+ except Exception as e:
+ raise e
+ else:
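+            # Direct connection: the client supplies the model definition itself, so skip the registry lookup and access-control check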
+ model = model_item
+ model_info = None
+
+ request.state.direct = True
+ request.state.model = model
metadata = {
"user_id": user.id,
@@ -880,6 +935,7 @@ async def chat_completion(
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model_info,
+ "direct": model_item.get("direct", False),
**(
{"function_calling": "native"}
if form_data.get("params", {}).get("function_calling") == "native"
@@ -891,6 +947,8 @@ async def chat_completion(
else {}
),
}
+
+ request.state.metadata = metadata
form_data["metadata"] = metadata
form_data, metadata, events = await process_chat_payload(
@@ -898,6 +956,7 @@ async def chat_completion(
)
except Exception as e:
+ log.debug(f"Error processing chat payload: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
@@ -926,6 +985,12 @@ async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
try:
+ model_item = form_data.pop("model_item", {})
+
+ if model_item.get("direct", False):
+ request.state.direct = True
+ request.state.model = model_item
+
return await chat_completed_handler(request, form_data, user)
except Exception as e:
raise HTTPException(
@@ -939,6 +1004,12 @@ async def chat_action(
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
try:
+ model_item = form_data.pop("model_item", {})
+
+ if model_item.get("direct", False):
+ request.state.direct = True
+ request.state.model = model_item
+
return await chat_action_handler(request, action_id, form_data, user)
except Exception as e:
raise HTTPException(
@@ -1011,14 +1082,17 @@ async def get_app_config(request: Request):
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
**(
{
+ "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
- "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
+ "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
+ "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
+ "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
}
if user is not None
else {}
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 73ff6c102d..9e0a5865e9 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -470,7 +470,7 @@ class ChatTable:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
- # we check if the chat is still shared by checkng if a chat with the share_id exists
+ # we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py
index 5c196281f7..605299528d 100644
--- a/backend/open_webui/models/users.py
+++ b/backend/open_webui/models/users.py
@@ -271,6 +271,24 @@ class UsersTable:
except Exception:
return None
+ def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
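+        # Merge the incoming settings into the stored dict instead of overwriting the whole settings column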
+ try:
+ with get_db() as db:
+ user_settings = db.query(User).filter_by(id=id).first().settings
+
+ if user_settings is None:
+ user_settings = {}
+
+ user_settings.update(updated)
+
+ db.query(User).filter_by(id=id).update({"settings": user_settings})
+ db.commit()
+
+ user = db.query(User).filter_by(id=id).first()
+ return UserModel.model_validate(user)
+ except Exception:
+ return None
+
def delete_user_by_id(self, id: str) -> bool:
try:
# Remove User from Groups
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index b3d8b5eb8a..b8186b3f93 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -113,6 +113,34 @@ class OpenSearchClient:
return self._result_to_search_result(result)
+ def query(
+ self, collection_name: str, filter: dict, limit: Optional[int] = None
+ ) -> Optional[GetResult]:
+ if not self.has_collection(collection_name):
+ return None
+
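+        # Build a bool filter: each key/value pair in the filter dict becomes an exact-match term clause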
+ query_body = {
+ "query": {"bool": {"filter": []}},
+ "_source": ["text", "metadata"],
+ }
+
+ for field, value in filter.items():
+ query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+
+ size = limit if limit else 10
+
+ try:
+ result = self.client.search(
+ index=f"{self.index_prefix}_{collection_name}",
+ body=query_body,
+ size=size,
+ )
+
+ return self._result_to_get_result(result)
+
+        except Exception:
+ return None
+
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)
diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py
new file mode 100644
index 0000000000..f26da36f84
--- /dev/null
+++ b/backend/open_webui/retrieval/web/bocha.py
@@ -0,0 +1,65 @@
+import logging
+from typing import Optional
+
+import requests
+import json
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def _parse_response(response):
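+    # Flatten Bocha's response envelope (data -> webPages -> value) into a dict holding a "webpage" list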
+ result = {}
+ if "data" in response:
+ data = response["data"]
+ if "webPages" in data:
+ webPages = data["webPages"]
+ if "value" in webPages:
+ result["webpage"] = [
+ {
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "url": item.get("url", ""),
+ "snippet": item.get("snippet", ""),
+ "summary": item.get("summary", ""),
+ "siteName": item.get("siteName", ""),
+ "siteIcon": item.get("siteIcon", ""),
+ "datePublished": item.get("datePublished", "")
+ or item.get("dateLastCrawled", ""),
+ }
+ for item in webPages["value"]
+ ]
+ return result
+
+
+def search_bocha(
+ api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
+) -> list[SearchResult]:
+ """Search using Bocha's Search API and return the results as a list of SearchResult objects.
+
+ Args:
+ api_key (str): A Bocha Search API key
+ query (str): The query to search for
+ """
+ url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+ payload = json.dumps(
+ {"query": query, "summary": True, "freshness": "noLimit", "count": count}
+ )
+
+ response = requests.post(url, headers=headers, data=payload, timeout=5)
+ response.raise_for_status()
+ results = _parse_response(response.json())
+    log.debug(f"Bocha search results: {results}")
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
+
+ return [
+ SearchResult(
+ link=result["url"], title=result.get("name"), snippet=result.get("summary")
+ )
+ for result in results.get("webpage", [])[:count]
+ ]
diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py
index 2c51dd3c99..2d2b863b42 100644
--- a/backend/open_webui/retrieval/web/google_pse.py
+++ b/backend/open_webui/retrieval/web/google_pse.py
@@ -17,34 +17,53 @@ def search_google_pse(
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
+ Handles pagination for counts greater than 10.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
+ count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
+ filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
+
+ Returns:
+ list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
-
headers = {"Content-Type": "application/json"}
- params = {
- "cx": search_engine_id,
- "q": query,
- "key": api_key,
- "num": count,
- }
+ all_results = []
+ start_index = 1 # Google PSE start parameter is 1-based
- response = requests.request("GET", url, headers=headers, params=params)
- response.raise_for_status()
+ while count > 0:
+ num_results_this_page = min(count, 10) # Google PSE max results per page is 10
+ params = {
+ "cx": search_engine_id,
+ "q": query,
+ "key": api_key,
+ "num": num_results_this_page,
+ "start": start_index,
+ }
+ response = requests.request("GET", url, headers=headers, params=params)
+ response.raise_for_status()
+ json_response = response.json()
+ results = json_response.get("items", [])
+ if results: # check if results are returned. If not, no more pages to fetch.
+ all_results.extend(results)
+ count -= len(
+ results
+ ) # Decrement count by the number of results fetched in this page.
+ start_index += 10 # Increment start index for the next page
+ else:
+ break # No more results from Google PSE, break the loop
- json_response = response.json()
- results = json_response.get("items", [])
if filter_list:
- results = get_filtered_results(results, filter_list)
+ all_results = get_filtered_results(all_results, filter_list)
+
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
- for result in results
+ for result in all_results
]
diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py
index 3de6c18077..a87293db5c 100644
--- a/backend/open_webui/retrieval/web/jina_search.py
+++ b/backend/open_webui/retrieval/web/jina_search.py
@@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
- headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
- url = str(URL(jina_search_endpoint + query))
- response = requests.get(url, headers=headers)
+
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": api_key,
+ "X-Retain-Images": "none",
+ }
+
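+    # POST a JSON body rather than encoding the query in the URL, and cap the requested count at 10 per request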
+ payload = {"q": query, "count": count if count <= 10 else 10}
+
+ url = str(URL(jina_search_endpoint))
+ response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
- for result in data["data"][:count]:
+ for result in data["data"]:
results.append(
SearchResult(
link=result["url"],
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index c1b15772bd..e2d05ba908 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
+import mimetypes
from fastapi import (
Depends,
@@ -138,6 +139,7 @@ class STTConfigForm(BaseModel):
ENGINE: str
MODEL: str
WHISPER_MODEL: str
+ DEEPGRAM_API_KEY: str
class AudioConfigUpdateForm(BaseModel):
@@ -165,6 +167,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+ "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -190,6 +193,7 @@ async def update_audio_config(
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+ request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -214,6 +218,7 @@ async def update_audio_config(
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+ "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -521,6 +526,69 @@ def transcribe(request: Request, file_path):
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+ elif request.app.state.config.STT_ENGINE == "deepgram":
+        r = None  # ensure "r" is defined for the error handler below
+        try:
+ # Determine the MIME type of the file
+ mime, _ = mimetypes.guess_type(file_path)
+ if not mime:
+ mime = "audio/wav" # fallback to wav if undetectable
+
+ # Read the audio file
+ with open(file_path, "rb") as f:
+ file_data = f.read()
+
+ # Build headers and parameters
+ headers = {
+ "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
+ "Content-Type": mime,
+ }
+
+ # Add model if specified
+ params = {}
+ if request.app.state.config.STT_MODEL:
+ params["model"] = request.app.state.config.STT_MODEL
+
+ # Make request to Deepgram API
+ r = requests.post(
+ "https://api.deepgram.com/v1/listen",
+ headers=headers,
+ params=params,
+ data=file_data,
+ )
+ r.raise_for_status()
+ response_data = r.json()
+
+ # Extract transcript from Deepgram response
+ try:
+ transcript = response_data["results"]["channels"][0]["alternatives"][
+ 0
+ ].get("transcript", "")
+ except (KeyError, IndexError) as e:
+ log.error(f"Malformed response from Deepgram: {str(e)}")
+ raise Exception(
+ "Failed to parse Deepgram response - unexpected response format"
+ )
+ data = {"text": transcript.strip()}
+
+ # Save transcript
+ transcript_file = f"{file_dir}/{id}.json"
+ with open(transcript_file, "w") as f:
+ json.dump(data, f)
+
+ return data
+
+ except Exception as e:
+ log.exception(e)
+ detail = None
+ if r is not None:
+ try:
+ res = r.json()
+ if "error" in res:
+ detail = f"External: {res['error'].get('message', '')}"
+ except Exception:
+ detail = f"External: {e}"
+ raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index ef6c4d8c1f..016075234a 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -36,6 +36,98 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
+############################
+# Direct Connections Config
+############################
+
+
+class DirectConnectionsConfigForm(BaseModel):
+ ENABLE_DIRECT_CONNECTIONS: bool
+
+
+@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+ return {
+ "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ }
+
+
+@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def set_direct_connections_config(
+ request: Request,
+ form_data: DirectConnectionsConfigForm,
+ user=Depends(get_admin_user),
+):
+ request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
+ form_data.ENABLE_DIRECT_CONNECTIONS
+ )
+ return {
+ "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ }
+
+
+############################
+# CodeInterpreterConfig
+############################
+class CodeInterpreterConfigForm(BaseModel):
+ ENABLE_CODE_INTERPRETER: bool
+ CODE_INTERPRETER_ENGINE: str
+ CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
+ CODE_INTERPRETER_JUPYTER_URL: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
+
+
+@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
+async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
+ return {
+ "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+ "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+ "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+ "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+ "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+ }
+
+
+@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
+async def set_code_interpreter_config(
+ request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
+):
+ request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
+ request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
+ request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
+ form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
+ form_data.CODE_INTERPRETER_JUPYTER_URL
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+ )
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+ )
+
+ return {
+ "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+ "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+ "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+ "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+ "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+ }
+
+
############################
# SetDefaultModels
############################
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index 7160c2e86e..0513212571 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -3,30 +3,22 @@ import os
import uuid
from pathlib import Path
from typing import Optional
-from pydantic import BaseModel
-import mimetypes
from urllib.parse import quote
-from open_webui.storage.provider import Storage
-
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
+from fastapi.responses import FileResponse, StreamingResponse
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
-from open_webui.routers.retrieval import process_file, ProcessFileForm
-
-from open_webui.config import UPLOAD_DIR
-from open_webui.env import SRC_LOG_LEVELS
-from open_webui.constants import ERROR_MESSAGES
-
-
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
-from fastapi.responses import FileResponse, StreamingResponse
-
-
+from open_webui.routers.retrieval import ProcessFileForm, process_file
+from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
+from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +33,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(
- request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+ request: Request,
+ file: UploadFile = File(...),
+ user=Depends(get_verified_user),
+ file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
try:
@@ -65,6 +60,7 @@ def upload_file(
"name": name,
"content_type": file.content_type,
"size": len(contents),
+ "data": file_metadata,
},
}
),
@@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files()
except Exception as e:
log.exception(e)
- log.error(f"Error deleting files")
+ log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
- log.error(f"Error getting file content")
+ log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
- log.error(f"Error getting file content")
+ log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
- log.error(f"Error deleting files")
+ log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 7afd9d106d..4046773dea 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -1,32 +1,26 @@
import asyncio
import base64
+import io
import json
import logging
import mimetypes
import re
-import uuid
from pathlib import Path
from typing import Optional
import requests
-
-
-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-
-
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
-
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
+from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
-
+from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
-
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
@@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None
-def save_b64_image(b64_str):
+def load_b64_image_data(b64_str):
try:
- image_id = str(uuid.uuid4())
-
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
-
img_data = base64.b64decode(encoded)
- image_format = mimetypes.guess_extension(mime_type)
-
- image_filename = f"{image_id}{image_format}"
- file_path = IMAGE_CACHE_DIR / f"{image_filename}"
- with open(file_path, "wb") as f:
- f.write(img_data)
- return image_filename
else:
- image_filename = f"{image_id}.png"
- file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
-
+ mime_type = "image/png"
img_data = base64.b64decode(b64_str)
-
- # Write the image data to a file
- with open(file_path, "wb") as f:
- f.write(img_data)
- return image_filename
-
+ return img_data, mime_type
except Exception as e:
- log.exception(f"Error saving image: {e}")
+ log.exception(f"Error loading image data: {e}")
return None
-def save_url_image(url, headers=None):
- image_id = str(uuid.uuid4())
+def load_url_image_data(url, headers=None):
try:
if headers:
r = requests.get(url, headers=headers)
@@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
- image_format = mimetypes.guess_extension(mime_type)
-
- if not image_format:
- raise ValueError("Could not determine image type from MIME type")
-
- image_filename = f"{image_id}{image_format}"
-
- file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
- with open(file_path, "wb") as image_file:
- for chunk in r.iter_content(chunk_size=8192):
- image_file.write(chunk)
- return image_filename
+ return r.content, mime_type
else:
log.error("Url does not point to an image.")
return None
@@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
return None
+def upload_image(request, image_metadata, image_data, content_type, user):
+ image_format = mimetypes.guess_extension(content_type)
+ file = UploadFile(
+ file=io.BytesIO(image_data),
+ filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
+ headers={
+ "content-type": content_type,
+ },
+ )
+ file_item = upload_file(request, file, user, file_metadata=image_metadata)
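+    # Resolve the app-relative URL for the stored file via the files router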
+ url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+ return url
+
+
@router.post("/generations")
async def image_generations(
request: Request,
@@ -500,13 +478,9 @@ async def image_generations(
images = []
for image in res["data"]:
- image_filename = save_b64_image(image["b64_json"])
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump(data, f)
-
+ image_data, content_type = load_b64_image_data(image["b64_json"])
+ url = upload_image(request, data, image_data, content_type, user)
+ images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -552,14 +526,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
- image_filename = save_url_image(image["url"], headers)
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump(form_data.model_dump(exclude_none=True), f)
-
- log.debug(f"images: {images}")
+ image_data, content_type = load_url_image_data(image["url"], headers)
+ url = upload_image(
+ request,
+ form_data.model_dump(exclude_none=True),
+ image_data,
+ content_type,
+ user,
+ )
+ images.append({"url": url})
return images
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@@ -604,13 +579,15 @@ async def image_generations(
images = []
for image in res["images"]:
- image_filename = save_b64_image(image)
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump({**data, "info": res["info"]}, f)
-
+ image_data, content_type = load_b64_image_data(image)
+ url = upload_image(
+ request,
+ {**data, "info": res["info"]},
+ image_data,
+ content_type,
+ user,
+ )
+ images.append({"url": url})
return images
except Exception as e:
error = e
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 2ab06eb95e..64373c616c 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -11,10 +11,8 @@ import re
import time
from typing import Optional, Union
from urllib.parse import urlparse
-
import aiohttp
from aiocache import cached
-
import requests
from fastapi import (
@@ -990,6 +988,8 @@ async def generate_chat_completion(
)
payload = {**form_data.model_dump(exclude_none=True)}
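+    # Strip Open WebUI's internal metadata before forwarding the payload upstream to Ollama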
+ if "metadata" in payload:
+ del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
@@ -1408,9 +1408,10 @@ async def download_model(
return None
+# TODO: Progress bar does not reflect size & duration of upload.
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
-def upload_model(
+async def upload_model(
request: Request,
file: UploadFile = File(...),
url_idx: Optional[int] = None,
@@ -1419,59 +1420,85 @@ def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
- file_path = f"{UPLOAD_DIR}/{file.filename}"
+ # --- P1: save file locally ---
+ chunk_size = 1024 * 1024 * 2 # 2 MB chunks
+ with open(file_path, "wb") as out_f:
+ while True:
+ chunk = file.file.read(chunk_size)
+ # log.info(f"Chunk: {str(chunk)}") # DEBUG
+ if not chunk:
+ break
+ out_f.write(chunk)
- # Save file in chunks
- with open(file_path, "wb+") as f:
- for chunk in file.file:
- f.write(chunk)
-
- def file_process_stream():
+ async def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
- chunk_size = 1024 * 1024
+        log.info(f"Total Model Size: {total_size}")  # DEBUG
+
+ # --- P2: SSE progress + calculate sha256 hash ---
+ file_hash = calculate_sha256(file_path, chunk_size)
+        log.info(f"Model Hash: {file_hash}")  # DEBUG
try:
with open(file_path, "rb") as f:
- total = 0
- done = False
-
- while not done:
- chunk = f.read(chunk_size)
- if not chunk:
- done = True
- continue
-
- total += len(chunk)
- progress = round((total / total_size) * 100, 2)
-
- res = {
+ bytes_read = 0
+ while chunk := f.read(chunk_size):
+ bytes_read += len(chunk)
+ progress = round(bytes_read / total_size * 100, 2)
+ data_msg = {
"progress": progress,
"total": total_size,
- "completed": total,
+ "completed": bytes_read,
}
- yield f"data: {json.dumps(res)}\n\n"
+ yield f"data: {json.dumps(data_msg)}\n\n"
- if done:
- f.seek(0)
- hashed = calculate_sha256(f)
- f.seek(0)
+ # --- P3: Upload to ollama /api/blobs ---
+ with open(file_path, "rb") as f:
+ url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
+ response = requests.post(url, data=f)
- url = f"{ollama_url}/api/blobs/sha256:{hashed}"
- response = requests.post(url, data=f)
+ if response.ok:
+                log.info("Uploaded blob to /api/blobs")  # DEBUG
+ # Remove local file
+ os.remove(file_path)
- if response.ok:
- res = {
- "done": done,
- "blob": f"sha256:{hashed}",
- "name": file.filename,
- }
- os.remove(file_path)
- yield f"data: {json.dumps(res)}\n\n"
- else:
- raise Exception(
- "Ollama: Could not create blob, Please try again."
- )
+ # Create model in ollama
+ model_name, ext = os.path.splitext(file.filename)
+                    log.info(f"Creating model: {model_name}")  # DEBUG
+
+ create_payload = {
+ "model": model_name,
+ # Reference the file by its original name => the uploaded blob's digest
+ "files": {file.filename: f"sha256:{file_hash}"},
+ }
+ log.info(f"Model Payload: {create_payload}") # DEBUG
+
+ # Call ollama /api/create
+ # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
+ create_resp = requests.post(
+ url=f"{ollama_url}/api/create",
+ headers={"Content-Type": "application/json"},
+ data=json.dumps(create_payload),
+ )
+
+ if create_resp.ok:
+                        log.info("Model created successfully via /api/create")  # DEBUG
+ done_msg = {
+ "done": True,
+ "blob": f"sha256:{file_hash}",
+ "name": file.filename,
+ "model_created": model_name,
+ }
+ yield f"data: {json.dumps(done_msg)}\n\n"
+ else:
+ raise Exception(
+ f"Failed to create model in Ollama. {create_resp.text}"
+ )
+
+ else:
+                raise Exception("Ollama: Could not create blob, please try again.")
except Exception as e:
res = {"error": str(e)}
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index d18f2a8ffc..afda362373 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -75,9 +75,9 @@ async def cleanup_response(
await session.close()
-def openai_o1_handler(payload):
+def openai_o1_o3_handler(payload):
"""
- Handle O1 specific parameters
+ Handle o1, o3 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
@@ -621,10 +621,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
- # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
- is_o1 = payload["model"].lower().startswith("o1-")
- if is_o1:
- payload = openai_o1_handler(payload)
+    # Fix: o1 and o3 models do not support the "max_tokens" parameter; modify "max_tokens" to "max_completion_tokens"
+ is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
+ if is_o1_o3:
+ payload = openai_o1_o3_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 77f04a4be5..e4bab52898 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -45,6 +45,7 @@ from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
+from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
@@ -379,6 +380,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+ "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -392,6 +394,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+ "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@@ -428,6 +431,7 @@ class WebSearchConfig(BaseModel):
brave_search_api_key: Optional[str] = None
kagi_search_api_key: Optional[str] = None
mojeek_search_api_key: Optional[str] = None
+ bocha_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
@@ -441,6 +445,7 @@ class WebSearchConfig(BaseModel):
exa_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
+ domain_filter_list: Optional[List[str]] = []
class WebConfig(BaseModel):
@@ -523,6 +528,9 @@ async def update_rag_config(
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
form_data.web.search.mojeek_search_api_key
)
+ request.app.state.config.BOCHA_SEARCH_API_KEY = (
+ form_data.web.search.bocha_search_api_key
+ )
request.app.state.config.SERPSTACK_API_KEY = (
form_data.web.search.serpstack_api_key
)
@@ -553,6 +561,9 @@ async def update_rag_config(
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
+ form_data.web.search.domain_filter_list
+ )
return {
"status": True,
@@ -586,6 +597,7 @@ async def update_rag_config(
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+ "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -599,6 +611,7 @@ async def update_rag_config(
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+ "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@@ -1107,6 +1120,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- BRAVE_SEARCH_API_KEY
- KAGI_SEARCH_API_KEY
- MOJEEK_SEARCH_API_KEY
+ - BOCHA_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
@@ -1174,6 +1188,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
+ elif engine == "bocha":
+ if request.app.state.config.BOCHA_SEARCH_API_KEY:
+ return search_bocha(
+ request.app.state.config.BOCHA_SEARCH_API_KEY,
+ query,
+ request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if request.app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
index f56a0232dd..91ec8e9723 100644
--- a/backend/open_webui/routers/tasks.py
+++ b/backend/open_webui/routers/tasks.py
@@ -139,7 +139,12 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
- models = request.app.state.MODELS
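+    # Direct-connection models never enter app.state.MODELS, so fall back to the model attached to this request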
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
}
),
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@ async def generate_chat_tags(
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
}
),
"chat_id": form_data.get("chat_id", None),
- "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+ "metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+ "task": str(TASKS.EMOJI_GENERATION),
+ "task_body": form_data,
+ },
}
try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
- models = request.app.state.MODELS
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
model_id = form_data["model"]
if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,
diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py
index ddcaef7674..872212d3ce 100644
--- a/backend/open_webui/routers/users.py
+++ b/backend/open_webui/routers/users.py
@@ -153,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
- user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+ user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
if user:
return user.settings
else:
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index 3788139eaa..6f59151227 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -279,8 +279,8 @@ def get_event_emitter(request_info):
await sio.emit(
"chat-events",
{
- "chat_id": request_info["chat_id"],
- "message_id": request_info["message_id"],
+ "chat_id": request_info.get("chat_id", None),
+ "message_id": request_info.get("message_id", None),
"data": event_data,
},
to=session_id,
@@ -329,8 +329,8 @@ def get_event_call(request_info):
response = await sio.call(
"chat-events",
{
- "chat_id": request_info["chat_id"],
- "message_id": request_info["message_id"],
+ "chat_id": request_info.get("chat_id", None),
+ "message_id": request_info.get("message_id", None),
"data": event_data,
},
to=request_info["session_id"],
diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py
index 0c0a8aacfc..b03cf0a7ec 100644
--- a/backend/open_webui/storage/provider.py
+++ b/backend/open_webui/storage/provider.py
@@ -10,6 +10,7 @@ from open_webui.config import (
S3_ACCESS_KEY_ID,
S3_BUCKET_NAME,
S3_ENDPOINT_URL,
+ S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
GCS_BUCKET_NAME,
@@ -93,15 +94,17 @@ class S3StorageProvider(StorageProvider):
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
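+ # Optional key prefix giving control over where open-webui objects live in the bucket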
+ self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
try:
- self.s3_client.upload_file(file_path, self.bucket_name, filename)
+ s3_key = os.path.join(self.key_prefix, filename)
+ self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
return (
open(file_path, "rb").read(),
- "s3://" + self.bucket_name + "/" + filename,
+ "s3://" + self.bucket_name + "/" + s3_key,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
@@ -109,18 +112,18 @@ class S3StorageProvider(StorageProvider):
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
try:
- bucket_name, key = file_path.split("//")[1].split("/")
- local_file_path = f"{UPLOAD_DIR}/{key}"
- self.s3_client.download_file(bucket_name, key, local_file_path)
+ s3_key = self._extract_s3_key(file_path)
+ local_file_path = self._get_local_file_path(s3_key)
+ self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from S3 storage."""
- filename = file_path.split("/")[-1]
try:
- self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+ s3_key = self._extract_s3_key(file_path)
+ self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
@@ -133,6 +136,10 @@ class S3StorageProvider(StorageProvider):
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
+ # Skip objects that were not uploaded by open-webui (keys outside the configured prefix)
+ if not content["Key"].startswith(self.key_prefix):
+ continue
+
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
@@ -142,6 +149,13 @@ class S3StorageProvider(StorageProvider):
# Always delete from local storage
LocalStorageProvider.delete_all_files()
+ def _extract_s3_key(self, full_file_path: str) -> str:
+ """Return the S3 object key: everything after the bucket name, including any internal path and the file name."""
+ return "/".join(full_file_path.split("//")[1].split("/")[1:])
+
+ def _get_local_file_path(self, s3_key: str) -> str:
+ return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
+
class GCSStorageProvider(StorageProvider):
def __init__(self):
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 0719f6af5b..253eaedfb9 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -7,14 +7,17 @@ from typing import Any, Optional
import random
import json
import inspect
+import uuid
+import asyncio
-from fastapi import Request
-from starlette.responses import Response, StreamingResponse
+from fastapi import Request, status
+from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.models.users import UserModel
from open_webui.socket.main import (
+ sio,
get_event_call,
get_event_emitter,
)
@@ -44,6 +47,10 @@ from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
+from open_webui.utils.filter import (
+ get_sorted_filter_ids,
+ process_filter_functions,
+)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
@@ -53,6 +60,101 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
+async def generate_direct_chat_completion(
+ request: Request,
+ form_data: dict,
+ user: Any,
+ models: dict,
+):
+ print("generate_direct_chat_completion")
+
+ metadata = form_data.pop("metadata", {})
+
+ user_id = metadata.get("user_id")
+ session_id = metadata.get("session_id")
+ request_id = str(uuid.uuid4()) # Generate a unique request ID
+
+ event_caller = get_event_call(metadata)
+
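+ # Unique per-request socket channel; streamed completion chunks are routed back on this name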
+ channel = f"{user_id}:{session_id}:{request_id}"
+
+ if form_data.get("stream"):
+ q = asyncio.Queue()
+
+ async def message_listener(sid, data):
+ """
+ Handle received socket messages and push them into the queue.
+ """
+ await q.put(data)
+
+ # Register the listener
+ sio.on(channel, message_listener)
+
+ # Start processing chat completion in background
+ res = await event_caller(
+ {
+ "type": "request:chat:completion",
+ "data": {
+ "form_data": form_data,
+ "model": models[form_data["model"]],
+ "channel": channel,
+ "session_id": session_id,
+ },
+ }
+ )
+
+ print("res", res)
+
+ if res.get("status", False):
+ # Define a generator to stream responses
+ async def event_generator():
+ nonlocal q
+ try:
+ while True:
+ data = await q.get() # Wait for new messages
+ if isinstance(data, dict):
+ if "done" in data and data["done"]:
+ break # Stop streaming when 'done' is received
+
+ yield f"data: {json.dumps(data)}\n\n"
+ elif isinstance(data, str):
+ yield data
+ except Exception as e:
+ log.debug(f"Error in event generator: {e}")
+
+ # Background task: remove the socket listener once streaming has finished
+ async def background():
+ try:
+ del sio.handlers["/"][channel]
+ except Exception:
+ pass
+
+ # Return the streaming response
+ return StreamingResponse(
+ event_generator(), media_type="text/event-stream", background=background
+ )
+ else:
+ raise Exception(str(res))
+ else:
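+ # Non-streaming: make a single socket call and wait for the complete response payload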
+ res = await event_caller(
+ {
+ "type": "request:chat:completion",
+ "data": {
+ "form_data": form_data,
+ "model": models[form_data["model"]],
+ "channel": channel,
+ "session_id": session_id,
+ },
+ }
+ )
+
+ if "error" in res:
+ raise Exception(res["error"])
+
+ return res
+
+
async def generate_chat_completion(
request: Request,
form_data: dict,
@@ -62,7 +164,16 @@ async def generate_chat_completion(
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
- models = request.app.state.MODELS
+ if hasattr(request.state, "metadata"):
+ form_data["metadata"] = request.state.metadata
+
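+ # Direct connections carry their own model definition on the request state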
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ log.debug(f"direct connection to model: {models}")
+ else:
+ models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -83,78 +194,90 @@ async def generate_chat_completion(
except Exception as e:
raise e
- if model["owned_by"] == "arena":
- model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
- filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
- if model_ids and filter_mode == "exclude":
- model_ids = [
- model["id"]
- for model in list(request.app.state.MODELS.values())
- if model.get("owned_by") != "arena" and model["id"] not in model_ids
- ]
-
- selected_model_id = None
- if isinstance(model_ids, list) and model_ids:
- selected_model_id = random.choice(model_ids)
- else:
- model_ids = [
- model["id"]
- for model in list(request.app.state.MODELS.values())
- if model.get("owned_by") != "arena"
- ]
- selected_model_id = random.choice(model_ids)
-
- form_data["model"] = selected_model_id
-
- if form_data.get("stream") == True:
-
- async def stream_wrapper(stream):
- yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
- async for chunk in stream:
- yield chunk
-
- response = await generate_chat_completion(
- request, form_data, user, bypass_filter=True
- )
- return StreamingResponse(
- stream_wrapper(response.body_iterator),
- media_type="text/event-stream",
- background=response.background,
- )
- else:
- return {
- **(
- await generate_chat_completion(
- request, form_data, user, bypass_filter=True
- )
- ),
- "selected_model_id": selected_model_id,
- }
-
- if model.get("pipe"):
- # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
- return await generate_function_chat_completion(
+ if getattr(request.state, "direct", False):
+ return await generate_direct_chat_completion(
request, form_data, user=user, models=models
)
- if model["owned_by"] == "ollama":
- # Using /ollama/api/chat endpoint
- form_data = convert_payload_openai_to_ollama(form_data)
- response = await generate_ollama_chat_completion(
- request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
- )
- if form_data.get("stream"):
- response.headers["content-type"] = "text/event-stream"
- return StreamingResponse(
- convert_streaming_response_ollama_to_openai(response),
- headers=dict(response.headers),
- background=response.background,
- )
- else:
- return convert_response_ollama_to_openai(response)
+
else:
- return await generate_openai_chat_completion(
- request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
- )
+ if model["owned_by"] == "arena":
+ model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+ filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+ if model_ids and filter_mode == "exclude":
+ model_ids = [
+ model["id"]
+ for model in list(request.app.state.MODELS.values())
+ if model.get("owned_by") != "arena" and model["id"] not in model_ids
+ ]
+
+ selected_model_id = None
+ if isinstance(model_ids, list) and model_ids:
+ selected_model_id = random.choice(model_ids)
+ else:
+ model_ids = [
+ model["id"]
+ for model in list(request.app.state.MODELS.values())
+ if model.get("owned_by") != "arena"
+ ]
+ selected_model_id = random.choice(model_ids)
+
+ form_data["model"] = selected_model_id
+
+ if form_data.get("stream") == True:
+
+ async def stream_wrapper(stream):
+ yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+ async for chunk in stream:
+ yield chunk
+
+ response = await generate_chat_completion(
+ request, form_data, user, bypass_filter=True
+ )
+ return StreamingResponse(
+ stream_wrapper(response.body_iterator),
+ media_type="text/event-stream",
+ background=response.background,
+ )
+ else:
+ return {
+ **(
+ await generate_chat_completion(
+ request, form_data, user, bypass_filter=True
+ )
+ ),
+ "selected_model_id": selected_model_id,
+ }
+
+ if model.get("pipe"):
+ # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
+ return await generate_function_chat_completion(
+ request, form_data, user=user, models=models
+ )
+ if model["owned_by"] == "ollama":
+ # Using /ollama/api/chat endpoint
+ form_data = convert_payload_openai_to_ollama(form_data)
+ response = await generate_ollama_chat_completion(
+ request=request,
+ form_data=form_data,
+ user=user,
+ bypass_filter=bypass_filter,
+ )
+ if form_data.get("stream"):
+ response.headers["content-type"] = "text/event-stream"
+ return StreamingResponse(
+ convert_streaming_response_ollama_to_openai(response),
+ headers=dict(response.headers),
+ background=response.background,
+ )
+ else:
+ return convert_response_ollama_to_openai(response)
+ else:
+ return await generate_openai_chat_completion(
+ request=request,
+ form_data=form_data,
+ user=user,
+ bypass_filter=bypass_filter,
+ )
chat_completion = generate_chat_completion
@@ -163,7 +286,13 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
- models = request.app.state.MODELS
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
data = form_data
model_id = data["model"]
@@ -177,116 +306,38 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
except Exception as e:
return Exception(f"Error: {e}")
- __event_emitter__ = get_event_emitter(
- {
- "chat_id": data["chat_id"],
- "message_id": data["id"],
- "session_id": data["session_id"],
- "user_id": user.id,
- }
- )
+ metadata = {
+ "chat_id": data["chat_id"],
+ "message_id": data["id"],
+ "session_id": data["session_id"],
+ "user_id": user.id,
+ }
- __event_call__ = get_event_call(
- {
- "chat_id": data["chat_id"],
- "message_id": data["id"],
- "session_id": data["session_id"],
- "user_id": user.id,
- }
- )
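+ # Shared context passed to every outlet filter handler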
+ extra_params = {
+ "__event_emitter__": get_event_emitter(metadata),
+ "__event_call__": get_event_call(metadata),
+ "__user__": {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ },
+ "__metadata__": metadata,
+ "__request__": request,
+ "__model__": model,
+ }
- def get_priority(function_id):
- function = Functions.get_function_by_id(function_id)
- if function is not None and hasattr(function, "valves"):
- # TODO: Fix FunctionModel to include vavles
- return (function.valves if function.valves else {}).get("priority", 0)
- return 0
-
- filter_ids = [function.id for function in Functions.get_global_filter_functions()]
- if "info" in model and "meta" in model["info"]:
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
- filter_ids = list(set(filter_ids))
-
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type("filter", active_only=True)
- ]
- filter_ids = [
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
- ]
-
- # Sort filter_ids by priority, using the get_priority function
- filter_ids.sort(key=get_priority)
-
- for filter_id in filter_ids:
- filter = Functions.get_function_by_id(filter_id)
- if not filter:
- continue
-
- if filter_id in request.app.state.FUNCTIONS:
- function_module = request.app.state.FUNCTIONS[filter_id]
- else:
- function_module, _, _ = load_function_module_by_id(filter_id)
- request.app.state.FUNCTIONS[filter_id] = function_module
-
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
- valves = Functions.get_function_valves_by_id(filter_id)
- function_module.valves = function_module.Valves(
- **(valves if valves else {})
- )
-
- if not hasattr(function_module, "outlet"):
- continue
- try:
- outlet = function_module.outlet
-
- # Get the signature of the function
- sig = inspect.signature(outlet)
- params = {"body": data}
-
- # Extra parameters to be passed to the function
- extra_params = {
- "__model__": model,
- "__id__": filter_id,
- "__event_emitter__": __event_emitter__,
- "__event_call__": __event_call__,
- "__request__": request,
- }
-
- # Add extra params in contained in function signature
- for key, value in extra_params.items():
- if key in sig.parameters:
- params[key] = value
-
- if "__user__" in sig.parameters:
- __user__ = {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- }
-
- try:
- if hasattr(function_module, "UserValves"):
- __user__["valves"] = function_module.UserValves(
- **Functions.get_user_valves_by_id_and_user_id(
- filter_id, user.id
- )
- )
- except Exception as e:
- print(e)
-
- params = {**params, "__user__": __user__}
-
- if inspect.iscoroutinefunction(outlet):
- data = await outlet(**params)
- else:
- data = outlet(**params)
-
- except Exception as e:
- return Exception(f"Error: {e}")
-
- return data
+ try:
+ result, _ = await process_filter_functions(
+ request=request,
+ filter_ids=get_sorted_filter_ids(model),
+ filter_type="outlet",
+ form_data=data,
+ extra_params=extra_params,
+ )
+ return result
+ except Exception as e:
+ return Exception(f"Error: {e}")
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
@@ -301,7 +352,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
if not request.app.state.MODELS:
await get_all_models(request)
- models = request.app.state.MODELS
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
data = form_data
model_id = data["model"]
diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py
new file mode 100644
index 0000000000..0a74da9c77
--- /dev/null
+++ b/backend/open_webui/utils/code_interpreter.py
@@ -0,0 +1,148 @@
+import asyncio
+import json
+import uuid
+import websockets
+import requests
+from urllib.parse import urljoin
+
+
+async def execute_code_jupyter(
+ jupyter_url, code, token=None, password=None, timeout=10
+):
+ """
+ Executes Python code in a Jupyter kernel.
+ Supports authentication with a token or password.
+ :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
+ :param code: Code to execute
+ :param token: Jupyter authentication token (optional)
+ :param password: Jupyter password (optional)
+ :param timeout: WebSocket timeout in seconds (default: 10s)
+ :return: Dictionary with stdout, stderr, and result
+ - Images are returned as "data:image/png;base64,..." strings, separated by newlines if multiple.
+ """
+ session = requests.Session() # Maintain cookies
+ headers = {} # Headers for requests
+
+ # Authenticate using password
+ if password and not token:
+ try:
+ login_url = urljoin(jupyter_url, "/login")
+ response = session.get(login_url)
+ response.raise_for_status()
+ xsrf_token = session.cookies.get("_xsrf")
+ if not xsrf_token:
+ raise ValueError("Failed to fetch _xsrf token")
+
+ login_data = {"_xsrf": xsrf_token, "password": password}
+ login_response = session.post(
+ login_url, data=login_data, cookies=session.cookies
+ )
+ login_response.raise_for_status()
+ headers["X-XSRFToken"] = xsrf_token
+ except Exception as e:
+ return {
+ "stdout": "",
+ "stderr": f"Authentication Error: {str(e)}",
+ "result": "",
+ }
+
+ # Construct API URLs with authentication token if provided
+ params = f"?token={token}" if token else ""
+ kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
+
+ kernel_id = None # Let the finally block tell whether a kernel was created
+ try:
+ response = session.post(kernel_url, headers=headers, cookies=session.cookies)
+ response.raise_for_status()
+ kernel_id = response.json()["id"]
+
+ websocket_url = urljoin(
+ jupyter_url.replace("http", "ws"),
+ f"/api/kernels/{kernel_id}/channels{params}",
+ )
+
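+ # For password auth, forward the XSRF token and session cookies on the WebSocket handshake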
+ ws_headers = {}
+ if password and not token:
+ ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
+ cookies = {name: value for name, value in session.cookies.items()}
+ ws_headers["Cookie"] = "; ".join(
+ [f"{name}={value}" for name, value in cookies.items()]
+ )
+
+ async with websockets.connect(
+ websocket_url, additional_headers=ws_headers
+ ) as ws:
+ msg_id = str(uuid.uuid4())
+ execute_request = {
+ "header": {
+ "msg_id": msg_id,
+ "msg_type": "execute_request",
+ "username": "user",
+ "session": str(uuid.uuid4()),
+ "date": "",
+ "version": "5.3",
+ },
+ "parent_header": {},
+ "metadata": {},
+ "content": {
+ "code": code,
+ "silent": False,
+ "store_history": True,
+ "user_expressions": {},
+ "allow_stdin": False,
+ "stop_on_error": True,
+ },
+ "channel": "shell",
+ }
+ await ws.send(json.dumps(execute_request))
+
+ stdout, stderr, result = "", "", []
+
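+ # Read kernel messages for this request until the kernel reports an idle state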
+ while True:
+ try:
+ message = await asyncio.wait_for(ws.recv(), timeout)
+ message_data = json.loads(message)
+ if message_data.get("parent_header", {}).get("msg_id") == msg_id:
+ msg_type = message_data.get("msg_type")
+
+ if msg_type == "stream":
+ if message_data["content"]["name"] == "stdout":
+ stdout += message_data["content"]["text"]
+ elif message_data["content"]["name"] == "stderr":
+ stderr += message_data["content"]["text"]
+
+ elif msg_type in ("execute_result", "display_data"):
+ data = message_data["content"]["data"]
+ if "image/png" in data:
+ result.append(
+ f"data:image/png;base64,{data['image/png']}"
+ )
+ elif "text/plain" in data:
+ result.append(data["text/plain"])
+
+ elif msg_type == "error":
+ stderr += "\n".join(message_data["content"]["traceback"])
+
+ elif (
+ msg_type == "status"
+ and message_data["content"]["execution_state"] == "idle"
+ ):
+ break
+
+ except asyncio.TimeoutError:
+ stderr += "\nExecution timed out."
+ break
+
+ except Exception as e:
+ return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
+
+ finally:
+ if kernel_id:
+ # Delete the kernel; the kernel id belongs in the path, before any token query string
+ requests.delete(
+ urljoin(jupyter_url, f"/api/kernels/{kernel_id}{params}"),
+ headers=headers,
+ cookies=session.cookies,
+ )
+
+ return {
+ "stdout": stdout.strip(),
+ "stderr": stderr.strip(),
+ "result": "\n".join(result).strip() if result else "",
+ }
diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py
new file mode 100644
index 0000000000..de51bd46e5
--- /dev/null
+++ b/backend/open_webui/utils/filter.py
@@ -0,0 +1,99 @@
+import inspect
+import logging
+
+from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.models.functions import Functions
+
+log = logging.getLogger(__name__)
+
+
+def get_sorted_filter_ids(model):
+ def get_priority(function_id):
+ function = Functions.get_function_by_id(function_id)
+ if function is not None and hasattr(function, "valves"):
+ # TODO: Fix FunctionModel to include valves
+ return (function.valves if function.valves else {}).get("priority", 0)
+ return 0
+
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+ if "info" in model and "meta" in model["info"]:
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+ filter_ids = list(set(filter_ids))
+
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type("filter", active_only=True)
+ ]
+
+ filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
+ filter_ids.sort(key=get_priority)
+ return filter_ids
+
+
+async def process_filter_functions(
+ request, filter_ids, filter_type, form_data, extra_params
+):
+ skip_files = None
+
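+ # Run the filters in priority order; each handler may replace form_data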
+ for filter_id in filter_ids:
+ filter = Functions.get_function_by_id(filter_id)
+ if not filter:
+ continue
+
+ if filter_id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[filter_id]
+ else:
+ function_module, _, _ = load_function_module_by_id(filter_id)
+ request.app.state.FUNCTIONS[filter_id] = function_module
+
+ # Check if the function has a file_handler variable
+ if filter_type == "inlet" and hasattr(function_module, "file_handler"):
+ skip_files = function_module.file_handler
+
+ # Apply valves to the function
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+ valves = Functions.get_function_valves_by_id(filter_id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ # Prepare handler function
+ handler = getattr(function_module, filter_type, None)
+ if not handler:
+ continue
+
+ try:
+ # Prepare parameters
+ sig = inspect.signature(handler)
+ params = {"body": form_data} | {
+ k: v
+ for k, v in {
+ **extra_params,
+ "__id__": filter_id,
+ }.items()
+ if k in sig.parameters
+ }
+
+ # Handle user parameters
+ if "__user__" in sig.parameters:
+ if hasattr(function_module, "UserValves"):
+ try:
+ params["__user__"]["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(
+ filter_id, params["__user__"]["id"]
+ )
+ )
+ except Exception as e:
+ log.exception(e)
+
+ # Execute handler
+ if inspect.iscoroutinefunction(handler):
+ form_data = await handler(**params)
+ else:
+ form_data = handler(**params)
+
+ except Exception as e:
+ log.error(f"Error in {filter_type} handler {filter_id}: {e}")
+ raise e
+
+ # Handle file cleanup for inlet
+ if skip_files and "files" in form_data.get("metadata", {}):
+ del form_data["metadata"]["files"]
+
+ return form_data, {}
diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py
index 679fff9f64..b86c257591 100644
--- a/backend/open_webui/utils/images/comfyui.py
+++ b/backend/open_webui/utils/images/comfyui.py
@@ -161,7 +161,7 @@ async def comfyui_generate_image(
seed = (
payload.seed
if payload.seed
- else random.randint(0, 18446744073709551614)
+ else random.randint(0, 1125899906842624)
)
for node_id in node.node_ids:
workflow[node_id]["inputs"][node.key] = seed
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 06763483cb..4d70ddd65f 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -68,7 +68,11 @@ from open_webui.utils.misc import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
-
+from open_webui.utils.filter import (
+ get_sorted_filter_ids,
+ process_filter_functions,
+)
+from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.tasks import create_task
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
-async def chat_completion_filter_functions_handler(request, body, model, extra_params):
- skip_files = None
-
- def get_filter_function_ids(model):
- def get_priority(function_id):
- function = Functions.get_function_by_id(function_id)
- if function is not None and hasattr(function, "valves"):
- # TODO: Fix FunctionModel
- return (function.valves if function.valves else {}).get("priority", 0)
- return 0
-
- filter_ids = [
- function.id for function in Functions.get_global_filter_functions()
- ]
- if "info" in model and "meta" in model["info"]:
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
- filter_ids = list(set(filter_ids))
-
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type("filter", active_only=True)
- ]
-
- filter_ids = [
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
- ]
-
- filter_ids.sort(key=get_priority)
- return filter_ids
-
- filter_ids = get_filter_function_ids(model)
- for filter_id in filter_ids:
- filter = Functions.get_function_by_id(filter_id)
- if not filter:
- continue
-
- if filter_id in request.app.state.FUNCTIONS:
- function_module = request.app.state.FUNCTIONS[filter_id]
- else:
- function_module, _, _ = load_function_module_by_id(filter_id)
- request.app.state.FUNCTIONS[filter_id] = function_module
-
- # Check if the function has a file_handler variable
- if hasattr(function_module, "file_handler"):
- skip_files = function_module.file_handler
-
- # Apply valves to the function
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
- valves = Functions.get_function_valves_by_id(filter_id)
- function_module.valves = function_module.Valves(
- **(valves if valves else {})
- )
-
- if hasattr(function_module, "inlet"):
- try:
- inlet = function_module.inlet
-
- # Create a dictionary of parameters to be passed to the function
- params = {"body": body} | {
- k: v
- for k, v in {
- **extra_params,
- "__model__": model,
- "__id__": filter_id,
- }.items()
- if k in inspect.signature(inlet).parameters
- }
-
- if "__user__" in params and hasattr(function_module, "UserValves"):
- try:
- params["__user__"]["valves"] = function_module.UserValves(
- **Functions.get_user_valves_by_id_and_user_id(
- filter_id, params["__user__"]["id"]
- )
- )
- except Exception as e:
- print(e)
-
- if inspect.iscoroutinefunction(inlet):
- body = await inlet(**params)
- else:
- body = inlet(**params)
-
- except Exception as e:
- print(f"Error: {e}")
- raise e
-
- if skip_files and "files" in body.get("metadata", {}):
- del body["metadata"]["files"]
-
- return body, {}
-
-
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
@@ -572,13 +483,13 @@ async def chat_image_generation_handler(
{
"type": "status",
"data": {
- "description": f"An error occured while generating an image",
+ "description": f"An error occurred while generating an image",
"done": True,
},
}
)
- system_message_content = "