mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: async file upload
This commit is contained in:
parent
efebf4d3a0
commit
5e1f4fa0ff
6 changed files with 211 additions and 55 deletions
|
|
@ -6,8 +6,10 @@ from fnmatch import fnmatch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
|
BackgroundTasks,
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Depends,
|
Depends,
|
||||||
File,
|
File,
|
||||||
|
|
@ -18,6 +20,7 @@ from fastapi import (
|
||||||
status,
|
status,
|
||||||
Query,
|
Query,
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
@ -42,7 +45,6 @@ from pydantic import BaseModel
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,13 +85,64 @@ def has_access_to_file(
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
def process_uploaded_file(request, file, file_item, file_metadata, user):
    """Extract the content of an uploaded file and record the outcome.

    ``upload_file`` schedules this function as a background task, so nothing
    is returned to a caller: the result is written to the file's data record
    as ``{"status": "completed"}`` or ``{"status": "failed", "error": ...}``.

    Args:
        request: Incoming request; ``request.app.state.config`` is read for
            ``STT_SUPPORTED_CONTENT_TYPES`` and ``CONTENT_EXTRACTION_ENGINE``.
        file: The uploaded file object; only ``content_type`` is read here.
        file_item: The stored file record; ``id`` and ``path`` are read.
        file_metadata: Metadata passed through to ``transcribe``.
        user: The uploading user, forwarded to ``process_file``.
    """
    try:
        content_type = file.content_type

        if content_type:
            configured_types = getattr(
                request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
            )
            # Fall back to the default patterns when the setting is missing,
            # empty, or contains only blank entries.
            if configured_types and any(t.strip() for t in configured_types):
                stt_patterns = configured_types
            else:
                stt_patterns = ["audio/*", "video/webm"]

            if any(fnmatch(content_type, pattern) for pattern in stt_patterns):
                # The original read an unbound local ``file_path`` here, which
                # raised UnboundLocalError for every transcribable upload.
                # NOTE(review): assumes file_item.path holds the storage path
                # saved at insert time ("path": file_path) - confirm.
                file_path = Storage.get_file(file_item.path)
                result = transcribe(request, file_path, file_metadata)

                process_file(
                    request,
                    ProcessFileForm(
                        file_id=file_item.id, content=result.get("text", "")
                    ),
                    user=user,
                )
            elif (not content_type.startswith(("image/", "video/"))) or (
                request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
            ):
                process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
        else:
            log.info(
                f"File type {file.content_type} is not provided, but trying to process anyway"
            )
            process_file(request, ProcessFileForm(file_id=file_item.id), user=user)

        Files.update_file_data_by_id(
            file_item.id,
            {"status": "completed"},
        )
    except Exception as e:
        # log.exception keeps the traceback; this runs in the background, so
        # the log and the stored "error" field are the only failure traces.
        log.exception(f"Error processing file: {file_item.id}")
        Files.update_file_data_by_id(
            file_item.id,
            {
                "status": "failed",
                "error": str(e.detail) if hasattr(e, "detail") else str(e),
            },
        )
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=FileModelResponse)
|
@router.post("/", response_model=FileModelResponse)
|
||||||
def upload_file(
|
def upload_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
metadata: Optional[dict | str] = Form(None),
|
metadata: Optional[dict | str] = Form(None),
|
||||||
process: bool = Query(True),
|
process: bool = Query(True),
|
||||||
internal: bool = False,
|
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
log.info(f"file.content_type: {file.content_type}")
|
log.info(f"file.content_type: {file.content_type}")
|
||||||
|
|
@ -112,7 +165,7 @@ def upload_file(
|
||||||
# Remove the leading dot from the file extension
|
# Remove the leading dot from the file extension
|
||||||
file_extension = file_extension[1:] if file_extension else ""
|
file_extension = file_extension[1:] if file_extension else ""
|
||||||
|
|
||||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
if process and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||||
]
|
]
|
||||||
|
|
@ -147,6 +200,9 @@ def upload_file(
|
||||||
"id": id,
|
"id": id,
|
||||||
"filename": name,
|
"filename": name,
|
||||||
"path": file_path,
|
"path": file_path,
|
||||||
|
"data": {
|
||||||
|
**({"status": "pending"} if process else {}),
|
||||||
|
},
|
||||||
"meta": {
|
"meta": {
|
||||||
"name": name,
|
"name": name,
|
||||||
"content_type": file.content_type,
|
"content_type": file.content_type,
|
||||||
|
|
@ -156,58 +212,25 @@ def upload_file(
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if process:
|
if process:
|
||||||
try:
|
background_tasks.add_task(
|
||||||
if file.content_type:
|
process_uploaded_file,
|
||||||
stt_supported_content_types = getattr(
|
request,
|
||||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
file,
|
||||||
)
|
file_item,
|
||||||
|
file_metadata,
|
||||||
if any(
|
user,
|
||||||
fnmatch(file.content_type, content_type)
|
|
||||||
for content_type in (
|
|
||||||
stt_supported_content_types
|
|
||||||
if stt_supported_content_types
|
|
||||||
and any(t.strip() for t in stt_supported_content_types)
|
|
||||||
else ["audio/*", "video/webm"]
|
|
||||||
)
|
|
||||||
):
|
|
||||||
file_path = Storage.get_file(file_path)
|
|
||||||
result = transcribe(request, file_path, file_metadata)
|
|
||||||
|
|
||||||
process_file(
|
|
||||||
request,
|
|
||||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
|
||||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
|
||||||
):
|
|
||||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
|
||||||
else:
|
|
||||||
log.info(
|
|
||||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
|
||||||
)
|
|
||||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
|
||||||
|
|
||||||
file_item = Files.get_file_by_id(id=id)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(e)
|
|
||||||
log.error(f"Error processing file: {file_item.id}")
|
|
||||||
file_item = FileModelResponse(
|
|
||||||
**{
|
|
||||||
**file_item.model_dump(),
|
|
||||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_item:
|
|
||||||
return file_item
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
|
||||||
)
|
)
|
||||||
|
return {"status": True, **file_item.model_dump()}
|
||||||
|
else:
|
||||||
|
if file_item:
|
||||||
|
return file_item
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -334,6 +357,60 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}/process/status")
async def get_file_process_status(
    id: str, stream: bool = Query(False), user=Depends(get_verified_user)
):
    """Report the processing status of an uploaded file.

    Args:
        id: Identifier of the file to inspect.
        stream: When true, respond with a ``text/event-stream`` that emits the
            status repeatedly until it becomes ``completed`` or ``failed``;
            otherwise return the current status once as JSON.
        user: The verified requesting user.

    Returns:
        A ``StreamingResponse`` of ``data: {...}`` events when ``stream`` is
        true, else ``{"status": <status>}`` (``"pending"`` if none is stored).

    Raises:
        HTTPException: 404 if the file does not exist or the user is neither
            its owner, an admin, nor granted read access.
    """
    file = Files.get_file_by_id(id)

    if not file:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_read = (
        file.user_id == user.id
        or user.role == "admin"
        or has_access_to_file(id, "read", user)
    )
    if not can_read:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if not stream:
        # ``data`` may be unset on the record; treat that as "no status yet".
        return {"status": (file.data or {}).get("status", "pending")}

    # Upper bound on polling iterations. With the 0.5 s pause between polls
    # the stream stays open for at least about one hour (7200 * 0.5 s).
    MAX_FILE_PROCESSING_DURATION = 3600 * 2

    async def event_stream(file_item):
        # Keep the id separately: the lookup below may return nothing, and
        # the next iteration must still know which file to query.
        file_id = file_item.id

        for _ in range(MAX_FILE_PROCESSING_DURATION):
            file_item = Files.get_file_by_id(file_id)
            if file_item:
                data = file_item.model_dump().get("data") or {}
                # Named to avoid shadowing the imported fastapi ``status``.
                process_status = data.get("status")

                if process_status:
                    event = {"status": process_status}
                    if process_status == "failed":
                        event["error"] = data.get("error")

                    yield f"data: {json.dumps(event)}\n\n"
                    if process_status in ("completed", "failed"):
                        break
                else:
                    # Legacy: records without a status field end the stream
                    # immediately.
                    break

            await asyncio.sleep(0.5)

    return StreamingResponse(
        event_stream(file),
        media_type="text/event-stream",
    )
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# Get File Data Content By Id
|
# Get File Data Content By Id
|
||||||
############################
|
############################
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,9 @@ def upload_image(request, image_data, content_type, metadata, user):
|
||||||
"content-type": content_type,
|
"content-type": content_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
file_item = upload_file(
|
||||||
|
request, file=file, metadata=metadata, process=False, user=user
|
||||||
|
)
|
||||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1476,7 +1476,7 @@ def process_file(
|
||||||
log.debug(f"text_content: {text_content}")
|
log.debug(f"text_content: {text_content}")
|
||||||
Files.update_file_data_by_id(
|
Files.update_file_data_by_id(
|
||||||
file.id,
|
file.id,
|
||||||
{"content": text_content},
|
{"status": "completed", "content": text_content},
|
||||||
)
|
)
|
||||||
|
|
||||||
hash = calculate_sha256_string(text_content)
|
hash = calculate_sha256_string(text_content)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import { WEBUI_API_BASE_URL } from '$lib/constants';
|
import { WEBUI_API_BASE_URL } from '$lib/constants';
|
||||||
|
import { splitStream } from '$lib/utils';
|
||||||
|
|
||||||
export const uploadFile = async (token: string, file: File, metadata?: object | null) => {
|
export const uploadFile = async (token: string, file: File, metadata?: object | null) => {
|
||||||
const data = new FormData();
|
const data = new FormData();
|
||||||
|
|
@ -31,6 +32,75 @@ export const uploadFile = async (token: string, file: File, metadata?: object |
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (res) {
|
||||||
|
const status = await getFileProcessStatus(token, res.id);
|
||||||
|
|
||||||
|
if (status && status.ok) {
|
||||||
|
const reader = status.body
|
||||||
|
.pipeThrough(new TextDecoderStream())
|
||||||
|
.pipeThrough(splitStream('\n'))
|
||||||
|
.getReader();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { value, done } = await reader.read();
|
||||||
|
if (done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
let lines = value.split('\n');
|
||||||
|
|
||||||
|
for (const line of lines) {
|
||||||
|
if (line !== '') {
|
||||||
|
console.log(line);
|
||||||
|
if (line === 'data: [DONE]') {
|
||||||
|
console.log(line);
|
||||||
|
} else {
|
||||||
|
let data = JSON.parse(line.replace(/^data: /, ''));
|
||||||
|
console.log(data);
|
||||||
|
|
||||||
|
if (data?.error) {
|
||||||
|
console.error(data.error);
|
||||||
|
res.error = data.error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * Request the processing status of an uploaded file as a stream.
 *
 * Issues a GET to `/files/{id}/process/status?stream=true` and returns the raw
 * fetch `Response` so the caller can read its body. If the request rejects,
 * the rejection is logged and `null` is returned, unless the rejection value
 * carries a truthy `detail`, in which case that detail is thrown.
 *
 * @param token Bearer token sent in the `authorization` header.
 * @param id Identifier of the file whose status is requested.
 */
export const getFileProcessStatus = async (token: string, id: string) => {
	const queryParams = new URLSearchParams({ stream: 'true' });
	const url = `${WEBUI_API_BASE_URL}/files/${id}/process/status?${queryParams}`;

	const headers = {
		Accept: 'application/json',
		authorization: `Bearer ${token}`
	};

	let error = null;

	const res = await fetch(url, { method: 'GET', headers }).catch((err) => {
		error = err.detail;
		console.error(err);
		return null;
	});

	if (error) {
		throw error;
	}

	return res;
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,7 @@
|
||||||
import { fade } from 'svelte/transition';
|
import { fade } from 'svelte/transition';
|
||||||
import Tooltip from '../common/Tooltip.svelte';
|
import Tooltip from '../common/Tooltip.svelte';
|
||||||
import Sidebar from '../icons/Sidebar.svelte';
|
import Sidebar from '../icons/Sidebar.svelte';
|
||||||
|
import { uploadFile } from '$lib/apis/files';
|
||||||
|
|
||||||
export let chatIdProp = '';
|
export let chatIdProp = '';
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,12 @@
|
||||||
|
|
||||||
if (uploadedFile) {
|
if (uploadedFile) {
|
||||||
console.log(uploadedFile);
|
console.log(uploadedFile);
|
||||||
|
|
||||||
|
if (uploadedFile.error) {
|
||||||
|
console.warn('File upload warning:', uploadedFile.error);
|
||||||
|
toast.warning(uploadedFile.error);
|
||||||
|
}
|
||||||
|
|
||||||
knowledge.files = knowledge.files.map((item) => {
|
knowledge.files = knowledge.files.map((item) => {
|
||||||
if (item.itemId === tempItemId) {
|
if (item.itemId === tempItemId) {
|
||||||
item.id = uploadedFile.id;
|
item.id = uploadedFile.id;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue