diff --git a/README.md b/README.md
index 0fb03537df..56ab09b05d 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa
In the last part of the command, replace `open-webui` with your container name if it is different.
-Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
+Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
### Using the Dev Branch 🌙
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index bf6f1d0256..ff298dc5b9 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -660,6 +660,7 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
+S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
@@ -682,6 +683,17 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+
+####################################
+# DIRECT CONNECTIONS
+####################################
+
+ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
+ "ENABLE_DIRECT_CONNECTIONS",
+ "direct.enable",
+ os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
+)
+
####################################
# OLLAMA_BASE_URL
####################################
@@ -1325,6 +1337,54 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""
+####################################
+# Code Interpreter
+####################################
+
+ENABLE_CODE_INTERPRETER = PersistentConfig(
+ "ENABLE_CODE_INTERPRETER",
+ "code_interpreter.enable",
+ os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
+)
+
+CODE_INTERPRETER_ENGINE = PersistentConfig(
+ "CODE_INTERPRETER_ENGINE",
+ "code_interpreter.engine",
+ os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
+)
+
+CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
+ "CODE_INTERPRETER_PROMPT_TEMPLATE",
+ "code_interpreter.prompt_template",
+ os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_URL",
+ "code_interpreter.jupyter.url",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH",
+ "code_interpreter.jupyter.auth",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
+)
+
+CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
+ "code_interpreter.jupyter.auth_token",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
+)
+
+
+CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
+ "code_interpreter.jupyter.auth_password",
+ os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
+)
+
+
DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
@@ -1335,9 +1395,8 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- - If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
+ - **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- - **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
@@ -1645,7 +1704,7 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
- "rag.rag.web.search.domain.filter_list",
+ "rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
@@ -1690,6 +1749,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
+BOCHA_SEARCH_API_KEY = PersistentConfig(
+ "BOCHA_SEARCH_API_KEY",
+ "rag.web.search.bocha_search_api_key",
+ os.getenv("BOCHA_SEARCH_API_KEY", ""),
+)
+
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
@@ -2012,6 +2077,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
+# Add Deepgram configuration
+DEEPGRAM_API_KEY = PersistentConfig(
+ "DEEPGRAM_API_KEY",
+ "audio.stt.deepgram.api_key",
+ os.getenv("DEEPGRAM_API_KEY", ""),
+)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 00605e15dc..0be3887f82 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -92,6 +92,7 @@ log_sources = [
"RAG",
"WEBHOOK",
"SOCKET",
+ "OAUTH",
]
SRC_LOG_LEVELS = {}
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 863f58dea5..1de0348c32 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -97,6 +97,16 @@ from open_webui.config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
+ # Direct Connections
+ ENABLE_DIRECT_CONNECTIONS,
+ # Code Interpreter
+ ENABLE_CODE_INTERPRETER,
+ CODE_INTERPRETER_ENGINE,
+ CODE_INTERPRETER_PROMPT_TEMPLATE,
+ CODE_INTERPRETER_JUPYTER_URL,
+ CODE_INTERPRETER_JUPYTER_AUTH,
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@@ -130,6 +140,7 @@ from open_webui.config import (
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
WHISPER_MODEL,
+ DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
# Retrieval
@@ -180,6 +191,7 @@ from open_webui.config import (
EXA_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
+ BOCHA_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
@@ -322,7 +334,11 @@ class SPAStaticFiles(StaticFiles):
return await super().get_response(path, scope)
except (HTTPException, StarletteHTTPException) as ex:
if ex.status_code == 404:
- return await super().get_response("index.html", scope)
+ if path.endswith(".js"):
+ # Return 404 for javascript files
+ raise ex
+ else:
+ return await super().get_response("index.html", scope)
else:
raise ex
@@ -389,6 +405,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
app.state.OPENAI_MODELS = {}
+########################################
+#
+# DIRECT CONNECTIONS
+#
+########################################
+
+app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+
########################################
#
# WEBUI
@@ -514,6 +538,7 @@ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
+app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -569,6 +594,24 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
+########################################
+#
+# CODE INTERPRETER
+#
+########################################
+
+app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
+app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
+app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
+
+app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+)
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+)
########################################
#
@@ -611,6 +654,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
+app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@@ -753,6 +797,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
+
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -1011,14 +1056,17 @@ async def get_app_config(request: Request):
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
**(
{
+ "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
- "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
+ "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
+ "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
+ "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
}
if user is not None
else {}
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 73ff6c102d..9e0a5865e9 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -470,7 +470,7 @@ class ChatTable:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
- # we check if the chat is still shared by checkng if a chat with the share_id exists
+ # we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py
index 5c196281f7..605299528d 100644
--- a/backend/open_webui/models/users.py
+++ b/backend/open_webui/models/users.py
@@ -271,6 +271,24 @@ class UsersTable:
except Exception:
return None
+ def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
+ try:
+ with get_db() as db:
+ user_settings = db.query(User).filter_by(id=id).first().settings
+
+ if user_settings is None:
+ user_settings = {}
+
+ user_settings.update(updated)
+
+ db.query(User).filter_by(id=id).update({"settings": user_settings})
+ db.commit()
+
+ user = db.query(User).filter_by(id=id).first()
+ return UserModel.model_validate(user)
+ except Exception:
+ return None
+
def delete_user_by_id(self, id: str) -> bool:
try:
# Remove User from Groups
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index b3d8b5eb8a..b8186b3f93 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -113,6 +113,34 @@ class OpenSearchClient:
return self._result_to_search_result(result)
+ def query(
+ self, collection_name: str, filter: dict, limit: Optional[int] = None
+ ) -> Optional[GetResult]:
+ if not self.has_collection(collection_name):
+ return None
+
+ query_body = {
+ "query": {"bool": {"filter": []}},
+ "_source": ["text", "metadata"],
+ }
+
+ for field, value in filter.items():
+ query_body["query"]["bool"]["filter"].append({"term": {field: value}})
+
+ size = limit if limit else 10
+
+ try:
+ result = self.client.search(
+ index=f"{self.index_prefix}_{collection_name}",
+ body=query_body,
+ size=size,
+ )
+
+ return self._result_to_get_result(result)
+
+ except Exception as e:
+ return None
+
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)
diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py
new file mode 100644
index 0000000000..98bdae704f
--- /dev/null
+++ b/backend/open_webui/retrieval/web/bocha.py
@@ -0,0 +1,72 @@
+import logging
+from typing import Optional
+
+import requests
+import json
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+def _parse_response(response):
+ result = {}
+ if "data" in response:
+ data = response["data"]
+ if "webPages" in data:
+ webPages = data["webPages"]
+ if "value" in webPages:
+ result["webpage"] = [
+ {
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "url": item.get("url", ""),
+ "snippet": item.get("snippet", ""),
+ "summary": item.get("summary", ""),
+ "siteName": item.get("siteName", ""),
+ "siteIcon": item.get("siteIcon", ""),
+ "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
+ }
+ for item in webPages["value"]
+ ]
+ return result
+
+
+def search_bocha(
+ api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
+) -> list[SearchResult]:
+ """Search using Bocha's Search API and return the results as a list of SearchResult objects.
+
+ Args:
+ api_key (str): A Bocha Search API key
+ query (str): The query to search for
+ """
+ url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json"
+ }
+
+ payload = json.dumps({
+ "query": query,
+ "summary": True,
+ "freshness": "noLimit",
+ "count": count
+ })
+
+ response = requests.post(url, headers=headers, data=payload, timeout=5)
+ response.raise_for_status()
+ results = _parse_response(response.json())
+ print(results)
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
+
+ return [
+ SearchResult(
+ link=result["url"],
+ title=result.get("name"),
+ snippet=result.get("summary")
+ )
+ for result in results.get("webpage", [])[:count]
+ ]
+
diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py
index 2c51dd3c99..ea82252620 100644
--- a/backend/open_webui/retrieval/web/google_pse.py
+++ b/backend/open_webui/retrieval/web/google_pse.py
@@ -8,7 +8,6 @@ from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-
def search_google_pse(
api_key: str,
search_engine_id: str,
@@ -17,34 +16,51 @@ def search_google_pse(
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
+ Handles pagination for counts greater than 10.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
+ count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
+ filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
+
+ Returns:
+ list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
-
headers = {"Content-Type": "application/json"}
- params = {
- "cx": search_engine_id,
- "q": query,
- "key": api_key,
- "num": count,
- }
+ all_results = []
+ start_index = 1 # Google PSE start parameter is 1-based
- response = requests.request("GET", url, headers=headers, params=params)
- response.raise_for_status()
+ while count > 0:
+ num_results_this_page = min(count, 10) # Google PSE max results per page is 10
+ params = {
+ "cx": search_engine_id,
+ "q": query,
+ "key": api_key,
+ "num": num_results_this_page,
+ "start": start_index,
+ }
+ response = requests.request("GET", url, headers=headers, params=params)
+ response.raise_for_status()
+ json_response = response.json()
+ results = json_response.get("items", [])
+ if results: # check if results are returned. If not, no more pages to fetch.
+ all_results.extend(results)
+ count -= len(results) # Decrement count by the number of results fetched in this page.
+ start_index += 10 # Increment start index for the next page
+ else:
+ break # No more results from Google PSE, break the loop
- json_response = response.json()
- results = json_response.get("items", [])
if filter_list:
- results = get_filtered_results(results, filter_list)
+ all_results = get_filtered_results(all_results, filter_list)
+
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
- for result in results
+ for result in all_results
]
diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py
index 3de6c18077..a87293db5c 100644
--- a/backend/open_webui/retrieval/web/jina_search.py
+++ b/backend/open_webui/retrieval/web/jina_search.py
@@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
- headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
- url = str(URL(jina_search_endpoint + query))
- response = requests.get(url, headers=headers)
+
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": api_key,
+ "X-Retain-Images": "none",
+ }
+
+ payload = {"q": query, "count": count if count <= 10 else 10}
+
+ url = str(URL(jina_search_endpoint))
+ response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
- for result in data["data"][:count]:
+ for result in data["data"]:
results.append(
SearchResult(
link=result["url"],
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index c1b15772bd..e2d05ba908 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
+import mimetypes
from fastapi import (
Depends,
@@ -138,6 +139,7 @@ class STTConfigForm(BaseModel):
ENGINE: str
MODEL: str
WHISPER_MODEL: str
+ DEEPGRAM_API_KEY: str
class AudioConfigUpdateForm(BaseModel):
@@ -165,6 +167,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+ "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -190,6 +193,7 @@ async def update_audio_config(
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+ request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -214,6 +218,7 @@ async def update_audio_config(
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+ "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -521,6 +526,69 @@ def transcribe(request: Request, file_path):
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+ elif request.app.state.config.STT_ENGINE == "deepgram":
+ try:
+ # Determine the MIME type of the file
+ mime, _ = mimetypes.guess_type(file_path)
+ if not mime:
+ mime = "audio/wav" # fallback to wav if undetectable
+
+ # Read the audio file
+ with open(file_path, "rb") as f:
+ file_data = f.read()
+
+ # Build headers and parameters
+ headers = {
+ "Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
+ "Content-Type": mime,
+ }
+
+ # Add model if specified
+ params = {}
+ if request.app.state.config.STT_MODEL:
+ params["model"] = request.app.state.config.STT_MODEL
+
+ # Make request to Deepgram API
+ r = requests.post(
+ "https://api.deepgram.com/v1/listen",
+ headers=headers,
+ params=params,
+ data=file_data,
+ )
+ r.raise_for_status()
+ response_data = r.json()
+
+ # Extract transcript from Deepgram response
+ try:
+ transcript = response_data["results"]["channels"][0]["alternatives"][
+ 0
+ ].get("transcript", "")
+ except (KeyError, IndexError) as e:
+ log.error(f"Malformed response from Deepgram: {str(e)}")
+ raise Exception(
+ "Failed to parse Deepgram response - unexpected response format"
+ )
+ data = {"text": transcript.strip()}
+
+ # Save transcript
+ transcript_file = f"{file_dir}/{id}.json"
+ with open(transcript_file, "w") as f:
+ json.dump(data, f)
+
+ return data
+
+ except Exception as e:
+ log.exception(e)
+ detail = None
+ if r is not None:
+ try:
+ res = r.json()
+ if "error" in res:
+ detail = f"External: {res['error'].get('message', '')}"
+ except Exception:
+ detail = f"External: {e}"
+ raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index ef6c4d8c1f..016075234a 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -36,6 +36,98 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
+############################
+# Direct Connections Config
+############################
+
+
+class DirectConnectionsConfigForm(BaseModel):
+ ENABLE_DIRECT_CONNECTIONS: bool
+
+
+@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+ return {
+ "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ }
+
+
+@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
+async def set_direct_connections_config(
+ request: Request,
+ form_data: DirectConnectionsConfigForm,
+ user=Depends(get_admin_user),
+):
+ request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
+ form_data.ENABLE_DIRECT_CONNECTIONS
+ )
+ return {
+ "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ }
+
+
+############################
+# CodeInterpreterConfig
+############################
+class CodeInterpreterConfigForm(BaseModel):
+ ENABLE_CODE_INTERPRETER: bool
+ CODE_INTERPRETER_ENGINE: str
+ CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
+ CODE_INTERPRETER_JUPYTER_URL: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
+ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
+
+
+@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
+async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
+ return {
+ "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+ "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+ "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+ "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+ "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+ }
+
+
+@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
+async def set_code_interpreter_config(
+ request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
+):
+ request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
+ request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
+ request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
+ form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
+ form_data.CODE_INTERPRETER_JUPYTER_URL
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH
+ )
+
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
+ )
+ request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
+ form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
+ )
+
+ return {
+ "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
+ "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
+ "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
+ "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
+ "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
+ "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
+ "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
+ }
+
+
############################
# SetDefaultModels
############################
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index 7160c2e86e..0513212571 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -3,30 +3,22 @@ import os
import uuid
from pathlib import Path
from typing import Optional
-from pydantic import BaseModel
-import mimetypes
from urllib.parse import quote
-from open_webui.storage.provider import Storage
-
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
+from fastapi.responses import FileResponse, StreamingResponse
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
-from open_webui.routers.retrieval import process_file, ProcessFileForm
-
-from open_webui.config import UPLOAD_DIR
-from open_webui.env import SRC_LOG_LEVELS
-from open_webui.constants import ERROR_MESSAGES
-
-
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
-from fastapi.responses import FileResponse, StreamingResponse
-
-
+from open_webui.routers.retrieval import ProcessFileForm, process_file
+from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
+from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +33,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(
- request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+ request: Request,
+ file: UploadFile = File(...),
+ user=Depends(get_verified_user),
+ file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
try:
@@ -65,6 +60,7 @@ def upload_file(
"name": name,
"content_type": file.content_type,
"size": len(contents),
+ "data": file_metadata,
},
}
),
@@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files()
except Exception as e:
log.exception(e)
- log.error(f"Error deleting files")
+ log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
- log.error(f"Error getting file content")
+ log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
- log.error(f"Error getting file content")
+ log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
- log.error(f"Error deleting files")
+ log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 7afd9d106d..4046773dea 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -1,32 +1,26 @@
import asyncio
import base64
+import io
import json
import logging
import mimetypes
import re
-import uuid
from pathlib import Path
from typing import Optional
import requests
-
-
-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-
-
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
-
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
+from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
-
+from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
-
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
@@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None
-def save_b64_image(b64_str):
+def load_b64_image_data(b64_str):
try:
- image_id = str(uuid.uuid4())
-
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
-
img_data = base64.b64decode(encoded)
- image_format = mimetypes.guess_extension(mime_type)
-
- image_filename = f"{image_id}{image_format}"
- file_path = IMAGE_CACHE_DIR / f"{image_filename}"
- with open(file_path, "wb") as f:
- f.write(img_data)
- return image_filename
else:
- image_filename = f"{image_id}.png"
- file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
-
+ mime_type = "image/png"
img_data = base64.b64decode(b64_str)
-
- # Write the image data to a file
- with open(file_path, "wb") as f:
- f.write(img_data)
- return image_filename
-
+ return img_data, mime_type
except Exception as e:
- log.exception(f"Error saving image: {e}")
+ log.exception(f"Error loading image data: {e}")
return None
-def save_url_image(url, headers=None):
- image_id = str(uuid.uuid4())
+def load_url_image_data(url, headers=None):
try:
if headers:
r = requests.get(url, headers=headers)
@@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
- image_format = mimetypes.guess_extension(mime_type)
-
- if not image_format:
- raise ValueError("Could not determine image type from MIME type")
-
- image_filename = f"{image_id}{image_format}"
-
- file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
- with open(file_path, "wb") as image_file:
- for chunk in r.iter_content(chunk_size=8192):
- image_file.write(chunk)
- return image_filename
+ return r.content, mime_type
else:
log.error("Url does not point to an image.")
return None
@@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
return None
+def upload_image(request, image_metadata, image_data, content_type, user):
+ image_format = mimetypes.guess_extension(content_type)
+ file = UploadFile(
+ file=io.BytesIO(image_data),
+ filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
+ headers={
+ "content-type": content_type,
+ },
+ )
+ file_item = upload_file(request, file, user, file_metadata=image_metadata)
+ url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+ return url
+
+
@router.post("/generations")
async def image_generations(
request: Request,
@@ -500,13 +478,9 @@ async def image_generations(
images = []
for image in res["data"]:
- image_filename = save_b64_image(image["b64_json"])
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump(data, f)
-
+ image_data, content_type = load_b64_image_data(image["b64_json"])
+ url = upload_image(request, data, image_data, content_type, user)
+ images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -552,14 +526,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
- image_filename = save_url_image(image["url"], headers)
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump(form_data.model_dump(exclude_none=True), f)
-
- log.debug(f"images: {images}")
+ image_data, content_type = load_url_image_data(image["url"], headers)
+ url = upload_image(
+ request,
+ form_data.model_dump(exclude_none=True),
+ image_data,
+ content_type,
+ user,
+ )
+ images.append({"url": url})
return images
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@@ -604,13 +579,15 @@ async def image_generations(
images = []
for image in res["images"]:
- image_filename = save_b64_image(image)
- images.append({"url": f"/cache/image/generations/{image_filename}"})
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
-
- with open(file_body_path, "w") as f:
- json.dump({**data, "info": res["info"]}, f)
-
+ image_data, content_type = load_b64_image_data(image)
+ url = upload_image(
+ request,
+ {**data, "info": res["info"]},
+ image_data,
+ content_type,
+ user,
+ )
+ images.append({"url": url})
return images
except Exception as e:
error = e
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 2ab06eb95e..64373c616c 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -11,10 +11,8 @@ import re
import time
from typing import Optional, Union
from urllib.parse import urlparse
-
import aiohttp
from aiocache import cached
-
import requests
from fastapi import (
@@ -990,6 +988,8 @@ async def generate_chat_completion(
)
payload = {**form_data.model_dump(exclude_none=True)}
+ if "metadata" in payload:
+ del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
@@ -1408,9 +1408,10 @@ async def download_model(
return None
+# TODO: Progress bar does not reflect size & duration of upload.
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
-def upload_model(
+async def upload_model(
request: Request,
file: UploadFile = File(...),
url_idx: Optional[int] = None,
@@ -1419,59 +1420,85 @@ def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
- file_path = f"{UPLOAD_DIR}/{file.filename}"
+ # --- P1: save file locally ---
+ chunk_size = 1024 * 1024 * 2 # 2 MB chunks
+ with open(file_path, "wb") as out_f:
+ while True:
+ chunk = file.file.read(chunk_size)
+ # log.info(f"Chunk: {str(chunk)}") # DEBUG
+ if not chunk:
+ break
+ out_f.write(chunk)
- # Save file in chunks
- with open(file_path, "wb+") as f:
- for chunk in file.file:
- f.write(chunk)
-
- def file_process_stream():
+ async def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
- chunk_size = 1024 * 1024
+ log.info(f"Total Model Size: {str(total_size)}") # DEBUG
+
+ # --- P2: SSE progress + calculate sha256 hash ---
+ file_hash = calculate_sha256(file_path, chunk_size)
+ log.info(f"Model Hash: {str(file_hash)}") # DEBUG
try:
with open(file_path, "rb") as f:
- total = 0
- done = False
-
- while not done:
- chunk = f.read(chunk_size)
- if not chunk:
- done = True
- continue
-
- total += len(chunk)
- progress = round((total / total_size) * 100, 2)
-
- res = {
+ bytes_read = 0
+ while chunk := f.read(chunk_size):
+ bytes_read += len(chunk)
+ progress = round(bytes_read / total_size * 100, 2)
+ data_msg = {
"progress": progress,
"total": total_size,
- "completed": total,
+ "completed": bytes_read,
}
- yield f"data: {json.dumps(res)}\n\n"
+ yield f"data: {json.dumps(data_msg)}\n\n"
- if done:
- f.seek(0)
- hashed = calculate_sha256(f)
- f.seek(0)
+ # --- P3: Upload to ollama /api/blobs ---
+ with open(file_path, "rb") as f:
+ url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
+ response = requests.post(url, data=f)
- url = f"{ollama_url}/api/blobs/sha256:{hashed}"
- response = requests.post(url, data=f)
+ if response.ok:
+ log.info(f"Uploaded to /api/blobs") # DEBUG
+ # Remove local file
+ os.remove(file_path)
- if response.ok:
- res = {
- "done": done,
- "blob": f"sha256:{hashed}",
- "name": file.filename,
- }
- os.remove(file_path)
- yield f"data: {json.dumps(res)}\n\n"
- else:
- raise Exception(
- "Ollama: Could not create blob, Please try again."
- )
+ # Create model in ollama
+ model_name, ext = os.path.splitext(file.filename)
+ log.info(f"Created Model: {model_name}") # DEBUG
+
+ create_payload = {
+ "model": model_name,
+ # Reference the file by its original name => the uploaded blob's digest
+ "files": {file.filename: f"sha256:{file_hash}"},
+ }
+ log.info(f"Model Payload: {create_payload}") # DEBUG
+
+ # Call ollama /api/create
+ # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
+ create_resp = requests.post(
+ url=f"{ollama_url}/api/create",
+ headers={"Content-Type": "application/json"},
+ data=json.dumps(create_payload),
+ )
+
+ if create_resp.ok:
+ log.info(f"API SUCCESS!") # DEBUG
+ done_msg = {
+ "done": True,
+ "blob": f"sha256:{file_hash}",
+ "name": file.filename,
+ "model_created": model_name,
+ }
+ yield f"data: {json.dumps(done_msg)}\n\n"
+ else:
+ raise Exception(
+ f"Failed to create model in Ollama. {create_resp.text}"
+ )
+
+ else:
+ raise Exception("Ollama: Could not create blob, Please try again.")
except Exception as e:
res = {"error": str(e)}
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index d18f2a8ffc..afda362373 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -75,9 +75,9 @@ async def cleanup_response(
await session.close()
-def openai_o1_handler(payload):
+def openai_o1_o3_handler(payload):
"""
- Handle O1 specific parameters
+ Handle o1, o3 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
@@ -621,10 +621,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
- # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
- is_o1 = payload["model"].lower().startswith("o1-")
- if is_o1:
- payload = openai_o1_handler(payload)
+ # Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
+ is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
+ if is_o1_o3:
+ payload = openai_o1_o3_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 77f04a4be5..e4bab52898 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -45,6 +45,7 @@ from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
+from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
@@ -379,6 +380,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+ "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -392,6 +394,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+ "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@@ -428,6 +431,7 @@ class WebSearchConfig(BaseModel):
brave_search_api_key: Optional[str] = None
kagi_search_api_key: Optional[str] = None
mojeek_search_api_key: Optional[str] = None
+ bocha_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
@@ -441,6 +445,7 @@ class WebSearchConfig(BaseModel):
exa_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
+ domain_filter_list: Optional[List[str]] = []
class WebConfig(BaseModel):
@@ -523,6 +528,9 @@ async def update_rag_config(
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
form_data.web.search.mojeek_search_api_key
)
+ request.app.state.config.BOCHA_SEARCH_API_KEY = (
+ form_data.web.search.bocha_search_api_key
+ )
request.app.state.config.SERPSTACK_API_KEY = (
form_data.web.search.serpstack_api_key
)
@@ -553,6 +561,9 @@ async def update_rag_config(
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
+ form_data.web.search.domain_filter_list
+ )
return {
"status": True,
@@ -586,6 +597,7 @@ async def update_rag_config(
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
+ "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -599,6 +611,7 @@ async def update_rag_config(
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+ "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@@ -1107,6 +1120,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- BRAVE_SEARCH_API_KEY
- KAGI_SEARCH_API_KEY
- MOJEEK_SEARCH_API_KEY
+ - BOCHA_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
@@ -1174,6 +1188,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
+ elif engine == "bocha":
+ if request.app.state.config.BOCHA_SEARCH_API_KEY:
+ return search_bocha(
+ request.app.state.config.BOCHA_SEARCH_API_KEY,
+ query,
+ request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if request.app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py
index ddcaef7674..872212d3ce 100644
--- a/backend/open_webui/routers/users.py
+++ b/backend/open_webui/routers/users.py
@@ -153,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
- user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+ user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
if user:
return user.settings
else:
diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py
index 0c0a8aacfc..b03cf0a7ec 100644
--- a/backend/open_webui/storage/provider.py
+++ b/backend/open_webui/storage/provider.py
@@ -10,6 +10,7 @@ from open_webui.config import (
S3_ACCESS_KEY_ID,
S3_BUCKET_NAME,
S3_ENDPOINT_URL,
+ S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
GCS_BUCKET_NAME,
@@ -93,15 +94,17 @@ class S3StorageProvider(StorageProvider):
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
+ self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
try:
- self.s3_client.upload_file(file_path, self.bucket_name, filename)
+ s3_key = os.path.join(self.key_prefix, filename)
+ self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
return (
open(file_path, "rb").read(),
- "s3://" + self.bucket_name + "/" + filename,
+ "s3://" + self.bucket_name + "/" + s3_key,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
@@ -109,18 +112,18 @@ class S3StorageProvider(StorageProvider):
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
try:
- bucket_name, key = file_path.split("//")[1].split("/")
- local_file_path = f"{UPLOAD_DIR}/{key}"
- self.s3_client.download_file(bucket_name, key, local_file_path)
+ s3_key = self._extract_s3_key(file_path)
+ local_file_path = self._get_local_file_path(s3_key)
+ self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from S3 storage."""
- filename = file_path.split("/")[-1]
try:
- self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+ s3_key = self._extract_s3_key(file_path)
+ self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
@@ -133,6 +136,10 @@ class S3StorageProvider(StorageProvider):
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
+ # Skip objects that were not uploaded from open-webui in the first place
+ if not content["Key"].startswith(self.key_prefix):
+ continue
+
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
@@ -142,6 +149,13 @@ class S3StorageProvider(StorageProvider):
# Always delete from local storage
LocalStorageProvider.delete_all_files()
+ # The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
+ def _extract_s3_key(self, full_file_path: str) -> str:
+ return "/".join(full_file_path.split("//")[1].split("/")[1:])
+
+ def _get_local_file_path(self, s3_key: str) -> str:
+ return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
+
class GCSStorageProvider(StorageProvider):
def __init__(self):
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 0719f6af5b..3b6d5ea042 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -44,6 +44,10 @@ from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
+from open_webui.utils.filter import (
+ get_sorted_filter_ids,
+ process_filter_functions,
+)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
@@ -177,116 +181,38 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
except Exception as e:
return Exception(f"Error: {e}")
- __event_emitter__ = get_event_emitter(
- {
- "chat_id": data["chat_id"],
- "message_id": data["id"],
- "session_id": data["session_id"],
- "user_id": user.id,
- }
- )
+ metadata = {
+ "chat_id": data["chat_id"],
+ "message_id": data["id"],
+ "session_id": data["session_id"],
+ "user_id": user.id,
+ }
- __event_call__ = get_event_call(
- {
- "chat_id": data["chat_id"],
- "message_id": data["id"],
- "session_id": data["session_id"],
- "user_id": user.id,
- }
- )
+ extra_params = {
+ "__event_emitter__": get_event_emitter(metadata),
+ "__event_call__": get_event_call(metadata),
+ "__user__": {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ },
+ "__metadata__": metadata,
+ "__request__": request,
+ "__model__": model,
+ }
- def get_priority(function_id):
- function = Functions.get_function_by_id(function_id)
- if function is not None and hasattr(function, "valves"):
- # TODO: Fix FunctionModel to include vavles
- return (function.valves if function.valves else {}).get("priority", 0)
- return 0
-
- filter_ids = [function.id for function in Functions.get_global_filter_functions()]
- if "info" in model and "meta" in model["info"]:
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
- filter_ids = list(set(filter_ids))
-
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type("filter", active_only=True)
- ]
- filter_ids = [
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
- ]
-
- # Sort filter_ids by priority, using the get_priority function
- filter_ids.sort(key=get_priority)
-
- for filter_id in filter_ids:
- filter = Functions.get_function_by_id(filter_id)
- if not filter:
- continue
-
- if filter_id in request.app.state.FUNCTIONS:
- function_module = request.app.state.FUNCTIONS[filter_id]
- else:
- function_module, _, _ = load_function_module_by_id(filter_id)
- request.app.state.FUNCTIONS[filter_id] = function_module
-
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
- valves = Functions.get_function_valves_by_id(filter_id)
- function_module.valves = function_module.Valves(
- **(valves if valves else {})
- )
-
- if not hasattr(function_module, "outlet"):
- continue
- try:
- outlet = function_module.outlet
-
- # Get the signature of the function
- sig = inspect.signature(outlet)
- params = {"body": data}
-
- # Extra parameters to be passed to the function
- extra_params = {
- "__model__": model,
- "__id__": filter_id,
- "__event_emitter__": __event_emitter__,
- "__event_call__": __event_call__,
- "__request__": request,
- }
-
- # Add extra params in contained in function signature
- for key, value in extra_params.items():
- if key in sig.parameters:
- params[key] = value
-
- if "__user__" in sig.parameters:
- __user__ = {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- }
-
- try:
- if hasattr(function_module, "UserValves"):
- __user__["valves"] = function_module.UserValves(
- **Functions.get_user_valves_by_id_and_user_id(
- filter_id, user.id
- )
- )
- except Exception as e:
- print(e)
-
- params = {**params, "__user__": __user__}
-
- if inspect.iscoroutinefunction(outlet):
- data = await outlet(**params)
- else:
- data = outlet(**params)
-
- except Exception as e:
- return Exception(f"Error: {e}")
-
- return data
+ try:
+ result, _ = await process_filter_functions(
+ request=request,
+ filter_ids=get_sorted_filter_ids(model),
+ filter_type="outlet",
+ form_data=data,
+ extra_params=extra_params,
+ )
+ return result
+ except Exception as e:
+ return Exception(f"Error: {e}")
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py
new file mode 100644
index 0000000000..0a74da9c77
--- /dev/null
+++ b/backend/open_webui/utils/code_interpreter.py
@@ -0,0 +1,148 @@
+import asyncio
+import json
+import uuid
+import websockets
+import requests
+from urllib.parse import urljoin
+
+
+async def execute_code_jupyter(
+ jupyter_url, code, token=None, password=None, timeout=10
+):
+ """
+ Executes Python code in a Jupyter kernel.
+ Supports authentication with a token or password.
+ :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
+ :param code: Code to execute
+ :param token: Jupyter authentication token (optional)
+ :param password: Jupyter password (optional)
+ :param timeout: WebSocket timeout in seconds (default: 10s)
+ :return: Dictionary with stdout, stderr, and result
+ - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
+ """
+ session = requests.Session() # Maintain cookies
+ headers = {} # Headers for requests
+
+ # Authenticate using password
+ if password and not token:
+ try:
+ login_url = urljoin(jupyter_url, "/login")
+ response = session.get(login_url)
+ response.raise_for_status()
+ xsrf_token = session.cookies.get("_xsrf")
+ if not xsrf_token:
+ raise ValueError("Failed to fetch _xsrf token")
+
+ login_data = {"_xsrf": xsrf_token, "password": password}
+ login_response = session.post(
+ login_url, data=login_data, cookies=session.cookies
+ )
+ login_response.raise_for_status()
+ headers["X-XSRFToken"] = xsrf_token
+ except Exception as e:
+ return {
+ "stdout": "",
+ "stderr": f"Authentication Error: {str(e)}",
+ "result": "",
+ }
+
+ # Construct API URLs with authentication token if provided
+ params = f"?token={token}" if token else ""
+ kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
+
+ try:
+ response = session.post(kernel_url, headers=headers, cookies=session.cookies)
+ response.raise_for_status()
+ kernel_id = response.json()["id"]
+
+ websocket_url = urljoin(
+ jupyter_url.replace("http", "ws"),
+ f"/api/kernels/{kernel_id}/channels{params}",
+ )
+
+ ws_headers = {}
+ if password and not token:
+ ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
+ cookies = {name: value for name, value in session.cookies.items()}
+ ws_headers["Cookie"] = "; ".join(
+ [f"{name}={value}" for name, value in cookies.items()]
+ )
+
+ async with websockets.connect(
+ websocket_url, additional_headers=ws_headers
+ ) as ws:
+ msg_id = str(uuid.uuid4())
+ execute_request = {
+ "header": {
+ "msg_id": msg_id,
+ "msg_type": "execute_request",
+ "username": "user",
+ "session": str(uuid.uuid4()),
+ "date": "",
+ "version": "5.3",
+ },
+ "parent_header": {},
+ "metadata": {},
+ "content": {
+ "code": code,
+ "silent": False,
+ "store_history": True,
+ "user_expressions": {},
+ "allow_stdin": False,
+ "stop_on_error": True,
+ },
+ "channel": "shell",
+ }
+ await ws.send(json.dumps(execute_request))
+
+ stdout, stderr, result = "", "", []
+
+ while True:
+ try:
+ message = await asyncio.wait_for(ws.recv(), timeout)
+ message_data = json.loads(message)
+ if message_data.get("parent_header", {}).get("msg_id") == msg_id:
+ msg_type = message_data.get("msg_type")
+
+ if msg_type == "stream":
+ if message_data["content"]["name"] == "stdout":
+ stdout += message_data["content"]["text"]
+ elif message_data["content"]["name"] == "stderr":
+ stderr += message_data["content"]["text"]
+
+ elif msg_type in ("execute_result", "display_data"):
+ data = message_data["content"]["data"]
+ if "image/png" in data:
+ result.append(
+ f"data:image/png;base64,{data['image/png']}"
+ )
+ elif "text/plain" in data:
+ result.append(data["text/plain"])
+
+ elif msg_type == "error":
+ stderr += "\n".join(message_data["content"]["traceback"])
+
+ elif (
+ msg_type == "status"
+ and message_data["content"]["execution_state"] == "idle"
+ ):
+ break
+
+ except asyncio.TimeoutError:
+ stderr += "\nExecution timed out."
+ break
+
+ except Exception as e:
+ return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
+
+ finally:
+ if kernel_id:
+ requests.delete(
+ f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
+ )
+
+ return {
+ "stdout": stdout.strip(),
+ "stderr": stderr.strip(),
+ "result": "\n".join(result).strip() if result else "",
+ }
diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py
new file mode 100644
index 0000000000..de51bd46e5
--- /dev/null
+++ b/backend/open_webui/utils/filter.py
@@ -0,0 +1,99 @@
+import inspect
+from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.models.functions import Functions
+
+
+def get_sorted_filter_ids(model):
+ def get_priority(function_id):
+ function = Functions.get_function_by_id(function_id)
+ if function is not None and hasattr(function, "valves"):
+ # TODO: Fix FunctionModel to include vavles
+ return (function.valves if function.valves else {}).get("priority", 0)
+ return 0
+
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+ if "info" in model and "meta" in model["info"]:
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+ filter_ids = list(set(filter_ids))
+
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type("filter", active_only=True)
+ ]
+
+ filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
+ filter_ids.sort(key=get_priority)
+ return filter_ids
+
+
+async def process_filter_functions(
+ request, filter_ids, filter_type, form_data, extra_params
+):
+ skip_files = None
+
+ for filter_id in filter_ids:
+ filter = Functions.get_function_by_id(filter_id)
+ if not filter:
+ continue
+
+ if filter_id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[filter_id]
+ else:
+ function_module, _, _ = load_function_module_by_id(filter_id)
+ request.app.state.FUNCTIONS[filter_id] = function_module
+
+ # Check if the function has a file_handler variable
+ if filter_type == "inlet" and hasattr(function_module, "file_handler"):
+ skip_files = function_module.file_handler
+
+ # Apply valves to the function
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+ valves = Functions.get_function_valves_by_id(filter_id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ # Prepare handler function
+ handler = getattr(function_module, filter_type, None)
+ if not handler:
+ continue
+
+ try:
+ # Prepare parameters
+ sig = inspect.signature(handler)
+ params = {"body": form_data} | {
+ k: v
+ for k, v in {
+ **extra_params,
+ "__id__": filter_id,
+ }.items()
+ if k in sig.parameters
+ }
+
+ # Handle user parameters
+ if "__user__" in sig.parameters:
+ if hasattr(function_module, "UserValves"):
+ try:
+ params["__user__"]["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(
+ filter_id, params["__user__"]["id"]
+ )
+ )
+ except Exception as e:
+ print(e)
+
+ # Execute handler
+ if inspect.iscoroutinefunction(handler):
+ form_data = await handler(**params)
+ else:
+ form_data = handler(**params)
+
+ except Exception as e:
+ print(f"Error in {filter_type} handler {filter_id}: {e}")
+ raise e
+
+ # Handle file cleanup for inlet
+ if skip_files and "files" in form_data.get("metadata", {}):
+ del form_data["metadata"]["files"]
+
+ return form_data, {}
diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py
index 679fff9f64..b86c257591 100644
--- a/backend/open_webui/utils/images/comfyui.py
+++ b/backend/open_webui/utils/images/comfyui.py
@@ -161,7 +161,7 @@ async def comfyui_generate_image(
seed = (
payload.seed
if payload.seed
- else random.randint(0, 18446744073709551614)
+ else random.randint(0, 1125899906842624)
)
for node_id in node.node_ids:
workflow[node_id]["inputs"][node.key] = seed
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 06763483cb..5de8f8193c 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -68,7 +68,11 @@ from open_webui.utils.misc import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
-
+from open_webui.utils.filter import (
+ get_sorted_filter_ids,
+ process_filter_functions,
+)
+from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.tasks import create_task
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
-async def chat_completion_filter_functions_handler(request, body, model, extra_params):
- skip_files = None
-
- def get_filter_function_ids(model):
- def get_priority(function_id):
- function = Functions.get_function_by_id(function_id)
- if function is not None and hasattr(function, "valves"):
- # TODO: Fix FunctionModel
- return (function.valves if function.valves else {}).get("priority", 0)
- return 0
-
- filter_ids = [
- function.id for function in Functions.get_global_filter_functions()
- ]
- if "info" in model and "meta" in model["info"]:
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
- filter_ids = list(set(filter_ids))
-
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type("filter", active_only=True)
- ]
-
- filter_ids = [
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
- ]
-
- filter_ids.sort(key=get_priority)
- return filter_ids
-
- filter_ids = get_filter_function_ids(model)
- for filter_id in filter_ids:
- filter = Functions.get_function_by_id(filter_id)
- if not filter:
- continue
-
- if filter_id in request.app.state.FUNCTIONS:
- function_module = request.app.state.FUNCTIONS[filter_id]
- else:
- function_module, _, _ = load_function_module_by_id(filter_id)
- request.app.state.FUNCTIONS[filter_id] = function_module
-
- # Check if the function has a file_handler variable
- if hasattr(function_module, "file_handler"):
- skip_files = function_module.file_handler
-
- # Apply valves to the function
- if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
- valves = Functions.get_function_valves_by_id(filter_id)
- function_module.valves = function_module.Valves(
- **(valves if valves else {})
- )
-
- if hasattr(function_module, "inlet"):
- try:
- inlet = function_module.inlet
-
- # Create a dictionary of parameters to be passed to the function
- params = {"body": body} | {
- k: v
- for k, v in {
- **extra_params,
- "__model__": model,
- "__id__": filter_id,
- }.items()
- if k in inspect.signature(inlet).parameters
- }
-
- if "__user__" in params and hasattr(function_module, "UserValves"):
- try:
- params["__user__"]["valves"] = function_module.UserValves(
- **Functions.get_user_valves_by_id_and_user_id(
- filter_id, params["__user__"]["id"]
- )
- )
- except Exception as e:
- print(e)
-
- if inspect.iscoroutinefunction(inlet):
- body = await inlet(**params)
- else:
- body = inlet(**params)
-
- except Exception as e:
- print(f"Error: {e}")
- raise e
-
- if skip_files and "files" in body.get("metadata", {}):
- del body["metadata"]["files"]
-
- return body, {}
-
-
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
@@ -572,13 +483,13 @@ async def chat_image_generation_handler(
{
"type": "status",
"data": {
- "description": f"An error occured while generating an image",
+ "description": f"An error occurred while generating an image",
"done": True,
},
}
)
- system_message_content = "