diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index c235aeb611..684f929134 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -2,6 +2,8 @@ import json
 import logging
 from typing import Optional
 import asyncio
+from fastapi.responses import StreamingResponse
+
 
 from open_webui.utils.misc import get_message_list
 from open_webui.socket.main import get_event_emitter
@@ -202,6 +204,9 @@ def get_session_user_chat_usage_stats(
 ############################
 
 
+CHAT_EXPORT_PAGE_ITEM_COUNT = 10
+
+
 class ChatStatsExportList(BaseModel):
     type: str = "chats"
     items: list[ChatStatsExport]
@@ -209,6 +214,140 @@ class ChatStatsExportList(BaseModel):
     page: int
 
+
+def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
+    try:
+        def get_message_content_length(message):
+            content = message.get("content", "")
+            if isinstance(content, str):
+                return len(content)
+            elif isinstance(content, list):
+                return sum(
+                    len(item.get("text", ""))
+                    for item in content
+                    if item.get("type") == "text"
+                )
+            return 0
+
+        messages_map = chat.chat.get("history", {}).get("messages", {})
+        message_id = chat.chat.get("history", {}).get("currentId")
+
+        history_models = {}
+        history_message_count = len(messages_map)
+        history_user_messages = []
+        history_assistant_messages = []
+
+        export_messages = {}
+        for key, message in messages_map.items():
+            try:
+                content_length = get_message_content_length(message)
+
+                # Extract rating safely
+                rating = message.get("annotation", {}).get("rating")
+                tags = message.get("annotation", {}).get("tags")
+
+                message_stat = MessageStats(
+                    id=message.get("id"),
+                    role=message.get("role"),
+                    model=message.get("model"),
+                    timestamp=message.get("timestamp"),
+                    content_length=content_length,
+                    token_count=None,  # Populate if available, e.g. message.get("info", {}).get("token_count")
+                    rating=rating,
+                    tags=tags,
+                )
+
+                export_messages[key] = message_stat
+
+                # --- Aggregation Logic (copied/adapted from usage stats) ---
+                role = message.get("role", "")
+                if role == "user":
+                    history_user_messages.append(message)
+                elif role == "assistant":
+                    history_assistant_messages.append(message)
+                    model = message.get("model")
+                    if model:
+                        if model not in history_models:
+                            history_models[model] = 0
+                        history_models[model] += 1
+            except Exception as e:
+                log.debug(f"Error processing message {key}: {e}")
+                continue
+
+        # Calculate Averages
+        average_user_message_content_length = (
+            sum(get_message_content_length(m) for m in history_user_messages)
+            / len(history_user_messages)
+            if history_user_messages
+            else 0
+        )
+
+        average_assistant_message_content_length = (
+            sum(get_message_content_length(m) for m in history_assistant_messages)
+            / len(history_assistant_messages)
+            if history_assistant_messages
+            else 0
+        )
+
+        # Response Times
+        response_times = []
+        for message in history_assistant_messages:
+            user_message_id = message.get("parentId", None)
+            if user_message_id and user_message_id in messages_map:
+                user_message = messages_map[user_message_id]
+                # Ensure timestamps exist
+                t1 = message.get("timestamp")
+                t0 = user_message.get("timestamp")
+                if t1 and t0:
+                    response_times.append(t1 - t0)
+
+        average_response_time = (
+            sum(response_times) / len(response_times) if response_times else 0
+        )
+
+        # Current Message List Logic (Main path)
+        message_list = get_message_list(messages_map, message_id)
+        message_count = len(message_list)
+        models = {}
+        for message in reversed(message_list):
+            if message.get("role") == "assistant":
+                model = message.get("model")
+                if model:
+                    if model not in models:
+                        models[model] = 0
+                    models[model] += 1
+
+        # Construct Aggregate Stats
+        stats = AggregateChatStats(
+            average_response_time=average_response_time,
+            average_user_message_content_length=average_user_message_content_length,
+            average_assistant_message_content_length=average_assistant_message_content_length,
+            models=models,
+            message_count=message_count,
+            history_models=history_models,
+            history_message_count=history_message_count,
+            history_user_message_count=len(history_user_messages),
+            history_assistant_message_count=len(history_assistant_messages),
+        )
+
+        # Construct Chat Body
+        chat_body = ChatBody(
+            history=ChatHistoryStats(messages=export_messages, currentId=message_id)
+        )
+
+        return ChatStatsExport(
+            id=chat.id,
+            user_id=chat.user_id,
+            created_at=chat.created_at,
+            updated_at=chat.updated_at,
+            tags=chat.meta.get("tags", []),
+            stats=stats,
+            chat=chat_body,
+        )
+    except Exception as e:
+        log.exception(f"Error exporting stats for chat {chat.id}: {e}")
+        return None
+
+
 def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
     if filter is None:
         filter = {}
@@ -221,145 +360,37 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
     )
 
     chat_stats_export_list = []
-
-    def get_message_content_length(message):
-        content = message.get("content", "")
-        if isinstance(content, str):
-            return len(content)
-        elif isinstance(content, list):
-            return sum(
-                len(item.get("text", ""))
-                for item in content
-                if item.get("type") == "text"
-            )
-        return 0
-
     for chat in result.items:
-        try:
-            messages_map = chat.chat.get("history", {}).get("messages", {})
-            message_id = chat.chat.get("history", {}).get("currentId")
-
-            history_models = {}
-            history_message_count = len(messages_map)
-            history_user_messages = []
-            history_assistant_messages = []
-
-            export_messages = {}
-            for key, message in messages_map.items():
-                try:
-                    content_length = get_message_content_length(message)
-
-                    # Extract rating safely
-                    rating = message.get("annotation", {}).get("rating")
-                    tags = message.get("annotation", {}).get("tags")
-
-                    message_stat = MessageStats(
-                        id=message.get("id"),
-                        role=message.get("role"),
-                        model=message.get("model"),
-                        timestamp=message.get("timestamp"),
-                        content_length=content_length,
-                        token_count=None,  # Populate if available, e.g. message.get("info", {}).get("token_count")
-                        rating=rating,
-                        tags=tags,
-                    )
-
-                    export_messages[key] = message_stat
-
-                    # --- Aggregation Logic (copied/adapted from usage stats) ---
-                    role = message.get("role", "")
-                    if role == "user":
-                        history_user_messages.append(message)
-                    elif role == "assistant":
-                        history_assistant_messages.append(message)
-                        model = message.get("model")
-                        if model:
-                            if model not in history_models:
-                                history_models[model] = 0
-                            history_models[model] += 1
-                except Exception as e:
-                    log.debug(f"Error processing message {key}: {e}")
-                    continue
-
-            # Calculate Averages
-            average_user_message_content_length = (
-                sum(get_message_content_length(m) for m in history_user_messages)
-                / len(history_user_messages)
-                if history_user_messages
-                else 0
-            )
-
-            average_assistant_message_content_length = (
-                sum(get_message_content_length(m) for m in history_assistant_messages)
-                / len(history_assistant_messages)
-                if history_assistant_messages
-                else 0
-            )
-
-            # Response Times
-            response_times = []
-            for message in history_assistant_messages:
-                user_message_id = message.get("parentId", None)
-                if user_message_id and user_message_id in messages_map:
-                    user_message = messages_map[user_message_id]
-                    # Ensure timestamps exist
-                    t1 = message.get("timestamp")
-                    t0 = user_message.get("timestamp")
-                    if t1 and t0:
-                        response_times.append(t1 - t0)
-
-            average_response_time = (
-                sum(response_times) / len(response_times) if response_times else 0
-            )
-
-            # Current Message List Logic (Main path)
-            message_list = get_message_list(messages_map, message_id)
-            message_count = len(message_list)
-            models = {}
-            for message in reversed(message_list):
-                if message.get("role") == "assistant":
-                    model = message.get("model")
-                    if model:
-                        if model not in models:
-                            models[model] = 0
-                        models[model] += 1
-
-            # Construct Aggregate Stats
-            stats = AggregateChatStats(
-                average_response_time=average_response_time,
-                average_user_message_content_length=average_user_message_content_length,
-                average_assistant_message_content_length=average_assistant_message_content_length,
-                models=models,
-                message_count=message_count,
-                history_models=history_models,
-                history_message_count=history_message_count,
-                history_user_message_count=len(history_user_messages),
-                history_assistant_message_count=len(history_assistant_messages),
-            )
-
-            # Construct Chat Body
-            chat_body = ChatBody(
-                history=ChatHistoryStats(messages=export_messages, currentId=message_id)
-            )
-
-            chat_stat = ChatStatsExport(
-                id=chat.id,
-                user_id=chat.user_id,
-                created_at=chat.created_at,
-                updated_at=chat.updated_at,
-                tags=chat.meta.get("tags", []),
-                stats=stats,
-                chat=chat_body,
-            )
-
+        chat_stat = _process_chat_for_export(chat)
+        if chat_stat:
             chat_stats_export_list.append(chat_stat)
-        except Exception as e:
-            log.debug(f"Error exporting stats for chat {chat.id}: {e}")
-            continue
 
     return chat_stats_export_list, result.total
 
 
+async def generate_chat_stats_jsonl_generator(user_id, filter):
+    skip = 0
+    limit = CHAT_EXPORT_PAGE_ITEM_COUNT
+
+    while True:
+        # Use asyncio.to_thread to make the blocking DB call non-blocking
+        result = await asyncio.to_thread(
+            Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit
+        )
+        if not result.items:
+            break
+
+        for chat in result.items:
+            try:
+                chat_stat = await asyncio.to_thread(_process_chat_for_export, chat)
+                if chat_stat:
+                    yield chat_stat.model_dump_json() + "\n"
+            except Exception as e:
+                log.exception(f"Error processing chat {chat.id}: {e}")
+
+        skip += limit
+
+
 @router.get("/stats/export", response_model=ChatStatsExportList)
 async def export_chat_stats(
     request: Request,
@@ -367,6 +398,7 @@
     start_time: Optional[int] = None,
     end_time: Optional[int] = None,
     page: Optional[int] = 1,
+    stream: bool = False,
     user=Depends(get_verified_user),
 ):
     # Check if the user has permission to share/export chats
@@ -379,9 +411,6 @@
         )
 
     try:
-        limit = 10  # Fixed limit for export
-        skip = (page - 1) * limit
-
        # Fetch chats with date filtering
         filter = {"order_by": "created_at", "direction": "asc"}
 
@@ -395,11 +424,25 @@
         if end_time:
             filter["end_time"] = end_time
 
-        chat_stats_export_list, total = await asyncio.to_thread(
-            calculate_chat_stats, user.id, skip, limit, filter
-        )
+        if stream:
+            return StreamingResponse(
+                generate_chat_stats_jsonl_generator(user.id, filter),
+                media_type="application/x-ndjson",
+                headers={
+                    "Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
+                },
+            )
+        else:
+            limit = CHAT_EXPORT_PAGE_ITEM_COUNT
+            skip = (page - 1) * limit
 
-        return ChatStatsExportList(items=chat_stats_export_list, total=total, page=page)
+            chat_stats_export_list, total = await asyncio.to_thread(
+                calculate_chat_stats, user.id, skip, limit, filter
+            )
+
+            return ChatStatsExportList(
+                items=chat_stats_export_list, total=total, page=page
+            )
 
     except Exception as e:
         log.debug(f"Error exporting chat stats: {e}")
diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts
index 435ffda800..9915744365 100644
--- a/src/lib/apis/chats/index.ts
+++ b/src/lib/apis/chats/index.ts
@@ -1206,4 +1206,36 @@ export const exportChatStats = async (token: string, page: number = 1, params: o
 
 	return res;
 };
 
+export const downloadChatStats = async (
+	token: string = '',
+	chat_id: string | null = null,
+	start_time: number | null = null,
+	end_time: number | null = null
+): Promise<[Response | null, AbortController]> => {
+	const controller = new AbortController();
+	let error = null;
+	let url = `${WEBUI_API_BASE_URL}/chats/stats/export?stream=true`;
+	if (chat_id) url += `&chat_id=${chat_id}`;
+	if (start_time) url += `&start_time=${start_time}`;
+	if (end_time) url += `&end_time=${end_time}`;
+
+	const res = await fetch(url, {
+		signal: controller.signal,
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	}).catch((err) => {
+		console.error(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return [res, controller];
+};
diff --git a/src/lib/components/chat/Settings/SyncStatsModal.svelte b/src/lib/components/chat/Settings/SyncStatsModal.svelte
index b886f397ba..1a05cb8ebb 100644
--- a/src/lib/components/chat/Settings/SyncStatsModal.svelte
+++ b/src/lib/components/chat/Settings/SyncStatsModal.svelte
@@ -4,7 +4,7 @@
 	import { toast } from 'svelte-sonner';
 	import { onMount, getContext } from 'svelte';
 
-	import { exportChatStats } from '$lib/apis/chats';
+	import { exportChatStats, downloadChatStats } from '$lib/apis/chats';
 	import { getVersion } from '$lib/apis';
 
 	import Check from '$lib/components/icons/Check.svelte';
@@ -18,7 +18,19 @@
 	export let show = false;
 	export let eventData = null;
 
-	let loading = false;
+	let syncing = false;
+	let downloading = false;
+	let downloadController = null;
+
+	const cancelDownload = () => {
+		if (downloadController) {
+			downloadController.abort();
+			downloading = false;
+			syncing = false;
+			downloadController = null;
+		}
+	};
+
 	let completed = false;
 	let processedItemsCount = 0;
 	let total = 0;
@@ -42,7 +54,7 @@
 			);
 		}
 
-		loading = true;
+		syncing = true;
 		processedItemsCount = 0;
 		total = 0;
 		let page = 1;
@@ -89,9 +101,106 @@
 				'*'
 			);
 		}
-		loading = false;
+		syncing = false;
 		completed = true;
 	};
+
+	const downloadHandler = async () => {
+		if (downloading) {
+			cancelDownload();
+			return;
+		}
+
+		// Get total count first
+		const _res = await exportChatStats(localStorage.token, 1, eventData?.searchParams ?? {}).catch(() => {
+			return null;
+		});
+		if (_res) {
+			total = _res.total;
+		}
+
+		downloading = true;
+		syncing = true;
+		processedItemsCount = 0;
+		const resVersion = await getVersion(localStorage.token).catch(() => {
+			return null;
+		});
+
+		const version = resVersion ? resVersion.version : '0.0.0';
+		const filename = `open-webui-stats-${version}-${Date.now()}.json`;
+
+		const searchParams = eventData?.searchParams ?? {};
+		const [res, controller] = await downloadChatStats(
+			localStorage.token,
+			searchParams.chat_id,
+			searchParams.start_time,
+			searchParams.end_time
+		).catch((error) => {
+			toast.error(error?.detail || $i18n.t('An error occurred while downloading your stats.'));
+			return [null, null];
+		});
+
+		if (res) {
+			downloadController = controller;
+			const reader = res.body.getReader();
+			const decoder = new TextDecoder();
+
+			const items = [];
+			let buffer = '';
+
+			try {
+				while (true) {
+					const { done, value } = await reader.read();
+					if (done) break;
+
+					buffer += decoder.decode(value, { stream: true });
+					const lines = buffer.split('\n');
+					buffer = lines.pop() || '';
+
+					for (const line of lines) {
+						if (line.trim() !== '') {
+							try {
+								items.push(JSON.parse(line));
+								processedItemsCount++;
+							} catch (e) {
+								console.error('Error parsing line', e);
+							}
+						}
+					}
+				}
+
+				if (buffer.trim() !== '') {
+					try {
+						items.push(JSON.parse(buffer));
+						processedItemsCount++;
+					} catch (e) {
+						console.error('Error parsing buffer', e);
+					}
+				}
+
+				if (downloading) {
+					const blob = new Blob([JSON.stringify(items)], { type: 'application/json' });
+					const url = window.URL.createObjectURL(blob);
+					const a = document.createElement('a');
+					a.href = url;
+					a.download = filename;
+					document.body.appendChild(a);
+					a.click();
+					window.URL.revokeObjectURL(url);
+				}
+			} catch (e) {
+				console.error('Download error:', e);
+			} finally {
+				downloading = false;
+				syncing = false;
+				downloadController = null;
+			}
+		} else {
+			downloading = false;
+			syncing = false;
+			downloadController = null;
+		}
+	};
@@ -127,7 +236,7 @@
 				on:click={() => {
 					show = false;
 				}}
-				disabled={loading}
+				disabled={syncing}
 			>
@@ -166,10 +275,12 @@
-		{#if loading}
-
+		{#if syncing}
+
-				{$i18n.t('Syncing stats...')}
+
+				{downloading ? $i18n.t('Downloading stats...') : $i18n.t('Syncing stats...')}
+
 				{Math.round((processedItemsCount / total) * 100) || 0}%
@@ -181,23 +292,51 @@
 		{/if}
-
+
+
+
+
+
-
diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte
index c7bb8c5c50..ee2e5c16e1 100644
--- a/src/routes/+layout.svelte
+++ b/src/routes/+layout.svelte
@@ -91,7 +91,7 @@
 	let showRefresh = false;
 
-	let showSyncStatsModal = false;
+	let showSyncStatsModal = true;
 	let syncStatsEventData = null;
 
 	let heartbeatInterval = null;