From 4be99174be505b187812e0b48b9f9f7bf6897c52 Mon Sep 17 00:00:00 2001
From: Tim Baek
Date: Thu, 25 Dec 2025 18:11:17 -0500
Subject: [PATCH] refac

---
 backend/open_webui/models/chats.py  |  80 +++++++++++--
 backend/open_webui/routers/chats.py | 153 ++++++++++++++++++++++++
 2 files changed, 227 insertions(+), 6 deletions(-)

diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index d821985a4e..4082a40d15 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -210,6 +210,57 @@ class ChatUsageStatsListResponse(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class MessageStats(BaseModel):
+    """Per-message entry of a stats export: metadata and content length only."""
+
+    id: str
+    role: str
+    model: Optional[str] = None
+    content_length: int
+    token_count: Optional[int] = None
+    timestamp: Optional[int] = None
+    rating: Optional[int] = None  # Derived from message.annotation.rating
+
+
+class ChatHistoryStats(BaseModel):
+    """Message stats keyed the same way as the chat's history.messages map."""
+
+    messages: dict[str, MessageStats]
+    currentId: Optional[str] = None
+
+
+class ChatBody(BaseModel):
+    """Mirrors the stored chat's ``history`` nesting for exported stats."""
+
+    history: ChatHistoryStats
+
+
+class AggregateChatStats(BaseModel):
+    """Chat-level aggregates; ``history_*`` fields cover every stored message."""
+
+    average_response_time: float
+    average_user_message_content_length: float
+    average_assistant_message_content_length: float
+    models: dict[str, int]
+    message_count: int
+    history_models: dict[str, int]
+    history_message_count: int
+    history_user_message_count: int
+    history_assistant_message_count: int
+
+
+class ChatStatsExport(BaseModel):
+    """One exported chat: identifiers, tags, aggregates and message stats."""
+
+    id: str
+    user_id: str
+    created_at: int
+    updated_at: int
+    tags: list[str] = []
+    stats: AggregateChatStats
+    chat: ChatBody
+
+
 class ChatTable:
     def _clean_null_bytes(self, obj):
         """Recursively remove null bytes from strings in dict/list structures."""
@@ -750,14 +801,31 @@ class ChatTable:
         return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chats_by_user_id(
-        self, user_id: str, skip: Optional[int] = None, limit: Optional[int] = None
+        self,
+        user_id: str,
+        skip: Optional[int] = None,
+        limit: Optional[int] = None,
+        filter: Optional[dict] = None,
     ) -> ChatListResponse:
+        """Query the user's chats ordered by ``updated_at``, newest first.
+
+        ``filter`` may carry ``start_time`` / ``end_time``, applied as inclusive
+        bounds on ``created_at``. It is declared after ``skip`` and ``limit``
+        so that pre-existing positional calls keep their meaning.
+        """
         with get_db() as db:
-            query = (
-                db.query(Chat)
-                .filter_by(user_id=user_id)
-                .order_by(Chat.updated_at.desc())
-            )
+            query = db.query(Chat).filter_by(user_id=user_id)
+
+            if filter:
+                start_time = filter.get("start_time")
+                end_time = filter.get("end_time")
+                # Test against None so that a bound of 0 is still applied.
+                if start_time is not None:
+                    query = query.filter(Chat.created_at >= start_time)
+                if end_time is not None:
+                    query = query.filter(Chat.created_at <= end_time)
+
+            query = query.order_by(Chat.updated_at.desc())
 
             total = query.count()
 
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 8dde946a4d..ac914c63ff 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -13,6 +13,11 @@ from open_webui.models.chats import (
     ChatResponse,
     Chats,
     ChatTitleIdResponse,
+    ChatStatsExport,
+    AggregateChatStats,
+    ChatBody,
+    ChatHistoryStats,
+    MessageStats,
 )
 from open_webui.models.tags import TagModel, Tags
 from open_webui.models.folders import Folders
@@ -192,11 +197,159 @@ def get_session_user_chat_usage_stats(
     )
 
 
+############################
+# GetChatStatsExport
+############################
+
+
+def _average(values) -> float:
+    """Return the arithmetic mean of ``values``, or 0 when it is empty."""
+    return sum(values) / len(values) if values else 0
+
+
+def _build_chat_stats_export(chat) -> ChatStatsExport:
+    """Reduce one chat record to its ChatStatsExport summary.
+
+    Message text is not copied into the result: each message is represented
+    by its role, model, timestamp, rating and content length only.
+    """
+    # "history"/"messages" may be present but null, which a .get() default
+    # would not cover, so normalise to empty dicts explicitly.
+    history = chat.chat.get("history") or {}
+    messages_map = history.get("messages") or {}
+    message_id = history.get("currentId")
+
+    export_messages = {}
+    history_models = {}
+    user_content_lengths = []
+    assistant_content_lengths = []
+    response_times = []
+
+    for key, message in messages_map.items():
+        content = message.get("content", "")
+        # Non-string content (e.g. None) is counted as length 0.
+        content_length = len(content) if isinstance(content, str) else 0
+        role = message.get("role") or ""
+        model = message.get("model")
+
+        export_messages[key] = MessageStats(
+            # Fall back to the map key when the message carries no id of its
+            # own (assumes the history map is keyed by message id).
+            id=message.get("id") or key,
+            role=role,
+            model=model,
+            timestamp=message.get("timestamp"),
+            content_length=content_length,
+            token_count=None,  # Populate if available, e.g. message.get("info", {}).get("token_count")
+            # "annotation" can be null, so guard before reading the rating.
+            rating=(message.get("annotation") or {}).get("rating"),
+        )
+
+        if role == "user":
+            user_content_lengths.append(content_length)
+        elif role == "assistant":
+            assistant_content_lengths.append(content_length)
+            if model:
+                history_models[model] = history_models.get(model, 0) + 1
+
+            # Response time: this message's timestamp minus its parent's.
+            parent = messages_map.get(message.get("parentId"))
+            if parent:
+                t1 = message.get("timestamp")
+                t0 = parent.get("timestamp")
+                if t1 and t0:
+                    response_times.append(t1 - t0)
+
+    # Model usage along the message list resolved for the current message id.
+    message_list = get_message_list(messages_map, message_id)
+    models = {}
+    for message in reversed(message_list):
+        model = message.get("model")
+        if message.get("role") == "assistant" and model:
+            models[model] = models.get(model, 0) + 1
+
+    return ChatStatsExport(
+        id=chat.id,
+        user_id=chat.user_id,
+        created_at=chat.created_at,
+        updated_at=chat.updated_at,
+        tags=chat.meta.get("tags", []),
+        stats=AggregateChatStats(
+            average_response_time=_average(response_times),
+            average_user_message_content_length=_average(user_content_lengths),
+            average_assistant_message_content_length=_average(
+                assistant_content_lengths
+            ),
+            models=models,
+            message_count=len(message_list),
+            history_models=history_models,
+            history_message_count=len(messages_map),
+            history_user_message_count=len(user_content_lengths),
+            history_assistant_message_count=len(assistant_content_lengths),
+        ),
+        chat=ChatBody(
+            history=ChatHistoryStats(messages=export_messages, currentId=message_id)
+        ),
+    )
+
+
+@router.get("/stats/export", response_model=list[ChatStatsExport])
+async def export_chat_stats(
+    request: Request,
+    start_time: Optional[int] = None,
+    end_time: Optional[int] = None,
+    page: Optional[int] = 1,
+    user=Depends(get_verified_user),
+):
+    """Export statistics for one page of the requesting user's chats.
+
+    ``start_time`` / ``end_time`` are optional inclusive bounds on a chat's
+    ``created_at``. Chats are requested 50 at a time; ``page`` is 1-based.
+    Non-admin users are rejected unless ENABLE_COMMUNITY_SHARING is enabled.
+    """
+    # Check if the user has permission to share/export chats
+    if (user.role != "admin") and (
+        not request.app.state.config.ENABLE_COMMUNITY_SHARING
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    try:
+        # A missing or non-positive page falls back to the first page instead
+        # of producing a negative offset.
+        limit = 50
+        skip = (max(page or 1, 1) - 1) * limit
+
+        # Use "is not None" so that an explicit bound of 0 is still applied.
+        chat_filter = {}
+        if start_time is not None:
+            chat_filter["start_time"] = start_time
+        if end_time is not None:
+            chat_filter["end_time"] = end_time
+
+        result = Chats.get_chats_by_user_id(
+            user.id,
+            skip=skip,
+            limit=limit,
+            filter=chat_filter,
+        )
+
+        return [_build_chat_stats_export(chat) for chat in result.items]
+
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
 ############################
 # DeleteAllChats
 ############################
 
 
 @router.delete("/", response_model=bool)
 async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
 