diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 727bfe65dd..127f22e103 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -66,7 +66,6 @@ from open_webui.socket.main import ( periodic_usage_pool_cleanup, get_event_emitter, get_models_in_use, - get_active_user_ids, ) from open_webui.routers import ( audio, @@ -2021,7 +2020,10 @@ async def get_current_usage(user=Depends(get_verified_user)): This is an experimental endpoint and subject to change. """ try: - return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()} + return { + "model_ids": get_models_in_use(), + "user_count": Users.get_active_user_count(), + } except Exception as e: log.error(f"Error getting usage statistics: {e}") raise HTTPException(status_code=500, detail="Internal Server Error") diff --git a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py new file mode 100644 index 0000000000..f35a382645 --- /dev/null +++ b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py @@ -0,0 +1,251 @@ +"""Update user table + +Revision ID: b10670c03dd5 +Revises: 2f1211949ecc +Create Date: 2025-11-28 04:55:31.737538 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +import open_webui.internal.db +import json +import time + +# revision identifiers, used by Alembic. +revision: str = "b10670c03dd5" +down_revision: Union[str, None] = "2f1211949ecc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _drop_sqlite_indexes_for_column(table_name, column_name, conn): + """ + SQLite requires manual removal of any indexes referencing a column + before ALTER TABLE ... DROP COLUMN can succeed. + """ + indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall() + + for idx in indexes: + index_name = idx[1] # index name + # Get indexed columns + idx_info = conn.execute( + sa.text(f"PRAGMA index_info('{index_name}')") + ).fetchall() + + indexed_cols = [row[2] for row in idx_info] # col names + if column_name in indexed_cols: + conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}")) + + +def _convert_column_to_json(table: str, column: str): + conn = op.get_bind() + dialect = conn.dialect.name + + # SQLite cannot ALTER COLUMN → must recreate column + if dialect == "sqlite": + # 1. Add temporary column + op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True)) + + # 2. Load old data + rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() + + for row in rows: + uid, raw = row + if raw is None: + parsed = None + else: + try: + parsed = json.loads(raw) + except Exception: + parsed = None # fallback safe behavior + + conn.execute( + sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'), + {"val": json.dumps(parsed) if parsed else None, "id": uid}, + ) + + # 3. Drop old TEXT column + op.drop_column(table, column) + + # 4. Rename new JSON column → original name + op.alter_column(table, f"{column}_json", new_column_name=column) + + else: + # PostgreSQL supports direct CAST + op.alter_column( + table, + column, + type_=sa.JSON(), + postgresql_using=f"{column}::json", + ) + + +def _convert_column_to_text(table: str, column: str): + conn = op.get_bind() + dialect = conn.dialect.name + + if dialect == "sqlite": + op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True)) + + rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() + + for uid, raw in rows: + conn.execute( + sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'), + {"val": json.dumps(raw) if raw else None, "id": uid}, + ) + + op.drop_column(table, column) + op.alter_column(table, f"{column}_text", new_column_name=column) + + else: + op.alter_column( + table, + column, + type_=sa.Text(), + postgresql_using=f"to_json({column})::text", + ) + + +def upgrade() -> None: + op.add_column( + "user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True) + ) + op.add_column("user", sa.Column("timezone", sa.String(), nullable=True)) + + op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True)) + op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True)) + op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True)) + op.add_column( + "user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True) + ) + + op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True)) + + # Convert info (TEXT/JSONField) → JSON + _convert_column_to_json("user", "info") + # Convert settings (TEXT/JSONField) → JSON + _convert_column_to_json("user", "settings") + + op.create_table( + "api_key", + sa.Column("id", sa.Text(), primary_key=True, unique=True), + sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")), + sa.Column("key", sa.Text(), unique=True, nullable=False), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("expires_at", sa.BigInteger(), nullable=True), + sa.Column("last_used_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + ) + + conn = op.get_bind() + users = conn.execute( + sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL') + ).fetchall() + + for uid, oauth_sub in users: + if oauth_sub: + # Example formats supported: + # provider@sub + # plain sub (stored as {"oidc": {"sub": sub}}) + if "@" in oauth_sub: + provider, sub = oauth_sub.split("@", 1) + else: + provider, sub = "oidc", oauth_sub + + oauth_json = json.dumps({provider: {"sub": sub}}) + conn.execute( + sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'), + {"oauth": oauth_json, "id": uid}, + ) + + users_with_keys = conn.execute( + sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL') + ).fetchall() + now = int(time.time()) + + for uid, api_key in users_with_keys: + if api_key: + conn.execute( + sa.text( + """ + INSERT INTO api_key (id, user_id, key, created_at, updated_at) + VALUES (:id, :user_id, :key, :created_at, :updated_at) + """ + ), + { + "id": f"key_{uid}", + "user_id": uid, + "key": api_key, + "created_at": now, + "updated_at": now, + }, + ) + + if conn.dialect.name == "sqlite": + _drop_sqlite_indexes_for_column("user", "api_key", conn) + _drop_sqlite_indexes_for_column("user", "oauth_sub", conn) + + with op.batch_alter_table("user") as batch_op: + batch_op.drop_column("api_key") + batch_op.drop_column("oauth_sub") + + +def downgrade() -> None: + # --- 1. Restore old oauth_sub column --- + op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True)) + + conn = op.get_bind() + users = conn.execute( + sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL') + ).fetchall() + + for uid, oauth in users: + try: + data = json.loads(oauth) + provider = list(data.keys())[0] + sub = data[provider].get("sub") + oauth_sub = f"{provider}@{sub}" + except Exception: + oauth_sub = None + + conn.execute( + sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'), + {"oauth_sub": oauth_sub, "id": uid}, + ) + + op.drop_column("user", "oauth") + + # --- 2. Restore api_key field --- + op.add_column("user", sa.Column("api_key", sa.String(), nullable=True)) + + # Restore values from api_key + keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall() + for uid, key in keys: + conn.execute( + sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'), + {"key": key, "id": uid}, + ) + + # Drop new table + op.drop_table("api_key") + + with op.batch_alter_table("user") as batch_op: + batch_op.drop_column("profile_banner_image_url") + batch_op.drop_column("timezone") + + batch_op.drop_column("presence_state") + batch_op.drop_column("status_emoji") + batch_op.drop_column("status_message") + batch_op.drop_column("status_expires_at") + + # Convert info (JSON) → TEXT + _convert_column_to_text("user", "info") + # Convert settings (JSON) → TEXT + _convert_column_to_text("user", "settings") diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 0d0b881a78..8b03580e6c 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -88,7 +88,7 @@ class AuthsTable: name: str, profile_image_url: str = "/user.png", role: str = "pending", - oauth_sub: Optional[str] = None, + oauth: Optional[dict] = None, ) -> Optional[UserModel]: with get_db() as db: log.info("insert_new_auth") @@ -102,7 +102,7 @@ class AuthsTable: db.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub + id, name, email, profile_image_url, role, oauth=oauth ) db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index e7beeee1bf..ede5f5e761 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -11,7 +11,17 @@ from open_webui.utils.misc import throttle from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, Date, exists, select +from sqlalchemy import ( + BigInteger, + JSON, + Column, + String, + Boolean, + Text, + Date, + exists, + select, +) from sqlalchemy import or_, case import datetime @@ -21,59 +31,71 @@ import datetime #################### -class User(Base): - __tablename__ = "user" - - id = Column(String, primary_key=True, unique=True) - name = Column(String) - - email = Column(String) - username = Column(String(50), nullable=True) - - role = Column(String) - profile_image_url = Column(Text) - - bio = Column(Text, nullable=True) - gender = Column(Text, nullable=True) - date_of_birth = Column(Date, nullable=True) - - info = Column(JSONField, nullable=True) - settings = Column(JSONField, nullable=True) - - api_key = Column(String, nullable=True, unique=True) - oauth_sub = Column(Text, unique=True) - - last_active_at = Column(BigInteger) - - updated_at = Column(BigInteger) - created_at = Column(BigInteger) - - class UserSettings(BaseModel): ui: Optional[dict] = {} model_config = ConfigDict(extra="allow") pass +class User(Base): + __tablename__ = "user" + + id = Column(String, primary_key=True, unique=True) + email = Column(String) + username = Column(String(50), nullable=True) + role = Column(String) + + name = Column(String) + + profile_image_url = Column(Text) + profile_banner_image_url = Column(Text, nullable=True) + + bio = Column(Text, nullable=True) + gender = Column(Text, nullable=True) + date_of_birth = Column(Date, nullable=True) + timezone = Column(String, nullable=True) + + presence_state = Column(String, nullable=True) + status_emoji = Column(String, nullable=True) + status_message = Column(Text, nullable=True) + status_expires_at = Column(BigInteger, nullable=True) + + info = Column(JSON, nullable=True) + settings = Column(JSON, nullable=True) + + oauth = Column(JSON, nullable=True) + + last_active_at = Column(BigInteger) + updated_at = Column(BigInteger) + created_at = Column(BigInteger) + + class UserModel(BaseModel): id: str - name: str email: str username: Optional[str] = None - role: str = "pending" + + name: str + profile_image_url: str + profile_banner_image_url: Optional[str] = None bio: Optional[str] = None gender: Optional[str] = None date_of_birth: Optional[datetime.date] = None + timezone: Optional[str] = None + + presence_state: Optional[str] = None + status_emoji: Optional[str] = None + status_message: Optional[str] = None + status_expires_at: Optional[int] = None info: Optional[dict] = None settings: Optional[UserSettings] = None - api_key: Optional[str] = None - oauth_sub: Optional[str] = None + oauth: Optional[dict] = None last_active_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -82,6 +104,32 @@ class UserModel(BaseModel): model_config = ConfigDict(from_attributes=True) +class ApiKey(Base): + __tablename__ = "api_key" + + id = Column(Text, primary_key=True, unique=True) + user_id = Column(Text, nullable=False) + key = Column(Text, unique=True, nullable=False) + data = Column(JSON, nullable=True) + expires_at = Column(BigInteger, nullable=True) + last_used_at = Column(BigInteger, nullable=True) + created_at = Column(BigInteger, nullable=False) + updated_at = Column(BigInteger, nullable=False) + + +class ApiKeyModel(BaseModel): + id: str + user_id: str + key: str + data: Optional[dict] = None + expires_at: Optional[int] = None + last_used_at: Optional[int] = None + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + #################### # Forms #################### @@ -128,7 +176,7 @@ class UserIdNameResponse(BaseModel): class UserIdNameStatusResponse(BaseModel): id: str name: str - is_active: bool + is_active: bool = False class UserInfoListResponse(BaseModel): @@ -177,20 +225,20 @@ class UsersTable: email: str, profile_image_url: str = "/user.png", role: str = "pending", - oauth_sub: Optional[str] = None, + oauth: Optional[dict] = None, ) -> Optional[UserModel]: with get_db() as db: user = UserModel( **{ "id": id, - "name": name, "email": email, + "name": name, "role": role, "profile_image_url": profile_image_url, "last_active_at": int(time.time()), "created_at": int(time.time()), "updated_at": int(time.time()), - "oauth_sub": oauth_sub, + "oauth": oauth, } ) result = User(**user.model_dump()) @@ -213,8 +261,13 @@ class UsersTable: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) + user = ( + db.query(User) + .join(ApiKey, User.id == ApiKey.user_id) + .filter(ApiKey.key == api_key) + .first() + ) + return UserModel.model_validate(user) if user else None except Exception: return None @@ -226,11 +279,15 @@ class UsersTable: except Exception: return None - def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: + def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) + user = ( + db.query(User) + .filter(User.oauth.contains({provider: {"sub": sub}})) + .first() + ) + return UserModel.model_validate(user) if user else None except Exception: return None @@ -432,7 +489,7 @@ class UsersTable: return None @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) - def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: + def update_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: db.query(User).filter_by(id=id).update( @@ -445,16 +502,35 @@ class UsersTable: except Exception: return None - def update_user_oauth_sub_by_id( - self, id: str, oauth_sub: str + def update_user_oauth_by_id( + self, id: str, provider: str, sub: str ) -> Optional[UserModel]: + """ + Update or insert an OAuth provider/sub pair into the user's oauth JSON field. + Example resulting structure: + { + "google": { "sub": "123" }, + "github": { "sub": "abc" } + } + """ try: with get_db() as db: - db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + user = db.query(User).filter_by(id=id).first() + if not user: + return None + + # Load existing oauth JSON or create empty + oauth = user.oauth or {} + + # Update or insert provider entry + oauth[provider] = {"sub": sub} + + # Persist updated JSON + db.query(User).filter_by(id=id).update({"oauth": oauth}) db.commit() - user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) + except Exception: return None @@ -508,23 +584,45 @@ class UsersTable: except Exception: return False - def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: - try: - with get_db() as db: - result = db.query(User).filter_by(id=id).update({"api_key": api_key}) - db.commit() - return True if result == 1 else False - except Exception: - return False - def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: with get_db() as db: - user = db.query(User).filter_by(id=id).first() - return user.api_key + api_key = db.query(ApiKey).filter_by(user_id=id).first() + return api_key.key if api_key else None except Exception: return None + def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: + try: + with get_db() as db: + db.query(ApiKey).filter_by(user_id=id).delete() + db.commit() + + now = int(time.time()) + new_api_key = ApiKey( + id=f"key_{id}", + user_id=id, + key=api_key, + created_at=now, + updated_at=now, + ) + db.add(new_api_key) + db.commit() + + return True + + except Exception: + return False + + def delete_user_api_key_by_id(self, id: str) -> bool: + try: + with get_db() as db: + db.query(ApiKey).filter_by(user_id=id).delete() + db.commit() + return True + except Exception: + return False + def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: with get_db() as db: users = db.query(User).filter(User.id.in_(user_ids)).all() @@ -538,5 +636,23 @@ class UsersTable: else: return None + def get_active_user_count(self) -> int: + with get_db() as db: + # Consider user active if last_active_at within the last 3 minutes + three_minutes_ago = int(time.time()) - 180 + count = ( + db.query(User).filter(User.last_active_at >= three_minutes_ago).count() + ) + return count + + def is_user_active(self, user_id: str) -> bool: + with get_db() as db: + user = db.query(User).filter_by(id=user_id).first() + if user and user.last_active_at: + # Consider user active if last_active_at within the last 3 minutes + three_minutes_ago = int(time.time()) - 180 + return user.last_active_at >= three_minutes_ago + return False + Users = UsersTable() diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 24cbd9a03f..1b79d84cfd 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -1133,8 +1133,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): # delete api key @router.delete("/api_key", response_model=bool) async def delete_api_key(user=Depends(get_current_user)): - success = Users.update_user_api_key_by_id(user.id, None) - return success + return Users.delete_user_api_key_by_id(user.id) # get api key diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index a3228f5c80..394c9f0009 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -10,7 +10,6 @@ from pydantic import BaseModel from open_webui.socket.main import ( sio, get_user_ids_from_room, - get_active_status_by_user_id, ) from open_webui.models.users import ( UserIdNameResponse, @@ -99,10 +98,7 @@ async def get_channels(user=Depends(get_verified_user)): ] users = [ UserIdNameStatusResponse( - **{ - **user.model_dump(), - "is_active": get_active_status_by_user_id(user.id), - } + **{**user.model_dump(), "is_active": Users.is_user_active(user.id)} ) for user in Users.get_users_by_user_ids(user_ids) ] @@ -284,7 +280,7 @@ async def get_channel_members_by_id( return { "users": [ UserModelResponse( - **user.model_dump(), is_active=get_active_status_by_user_id(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id) ) for user in users ], @@ -316,7 +312,7 @@ async def get_channel_members_by_id( return { "users": [ UserModelResponse( - **user.model_dump(), is_active=get_active_status_by_user_id(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id) ) for user in users ], diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 9b30ba8f20..7c4b801f4d 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -26,12 +26,6 @@ from open_webui.models.users import ( UserUpdateForm, ) - -from open_webui.socket.main import ( - get_active_status_by_user_id, - get_active_user_ids, - get_user_active_status, -) from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR @@ -51,23 +45,6 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() -############################ -# GetActiveUsers -############################ - - -@router.get("/active") -async def get_active_users( - user=Depends(get_verified_user), -): - """ - Get a list of active users. - """ - return { - "user_ids": get_active_user_ids(), - } - - ############################ # GetUsers ############################ @@ -364,7 +341,7 @@ async def update_user_info_by_session_user( class UserActiveResponse(BaseModel): name: str profile_image_url: Optional[str] = None - active: Optional[bool] = None + is_active: bool model_config = ConfigDict(extra="allow") @@ -390,7 +367,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): **{ "id": user.id, "name": user.name, - "active": get_active_status_by_user_id(user_id), + "is_active": Users.is_user_active(user_id), } ) else: @@ -457,7 +434,7 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u @router.get("/{user_id}/active", response_model=dict) async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): return { - "active": get_user_active_status(user_id), + "active": Users.is_user_active(user_id), } diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 04b67dd786..84705648d9 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -132,12 +132,6 @@ if WEBSOCKET_MANAGER == "redis": redis_sentinels=redis_sentinels, redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) - USER_POOL = RedisDict( - f"{REDIS_KEY_PREFIX}:user_pool", - redis_url=WEBSOCKET_REDIS_URL, - redis_sentinels=redis_sentinels, - redis_cluster=WEBSOCKET_REDIS_CLUSTER, - ) USAGE_POOL = RedisDict( f"{REDIS_KEY_PREFIX}:usage_pool", redis_url=WEBSOCKET_REDIS_URL, @@ -159,7 +153,6 @@ else: MODELS = {} SESSION_POOL = {} - USER_POOL = {} USAGE_POOL = {} aquire_func = release_func = renew_func = lambda: True @@ -235,16 +228,6 @@ def get_models_in_use(): return models_in_use -def get_active_user_ids(): - """Get the list of active user IDs.""" - return list(USER_POOL.keys()) - - -def get_user_active_status(user_id): - """Check if a user is currently active.""" - return user_id in USER_POOL - - def get_user_id_from_session_pool(sid): user = SESSION_POOL.get(sid) if user: @@ -270,12 +253,6 @@ def get_user_ids_from_room(room): return active_user_ids -def get_active_status_by_user_id(user_id): - if user_id in USER_POOL: - return True - return False - - @sio.on("usage") async def usage(sid, data): if sid in SESSION_POOL: @@ -303,11 +280,6 @@ async def connect(sid, environ, auth): SESSION_POOL[sid] = user.model_dump( exclude=["date_of_birth", "bio", "gender"] ) - if user.id in USER_POOL: - USER_POOL[user.id] = USER_POOL[user.id] + [sid] - else: - USER_POOL[user.id] = [sid] - await sio.enter_room(sid, f"user:{user.id}") @@ -326,11 +298,15 @@ async def user_join(sid, data): if not user: return - SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"]) - if user.id in USER_POOL: - USER_POOL[user.id] = USER_POOL[user.id] + [sid] - else: - USER_POOL[user.id] = [sid] + SESSION_POOL[sid] = user.model_dump( + exclude=[ + "profile_image_url", + "profile_banner_image_url", + "date_of_birth", + "bio", + "gender", + ] + ) await sio.enter_room(sid, f"user:{user.id}") # Join all the channels @@ -341,6 +317,13 @@ async def user_join(sid, data): return {"id": user.id, "name": user.name} +@sio.on("heartbeat") +async def heartbeat(sid, data): + user = SESSION_POOL.get(sid) + if user: + Users.update_last_active_by_id(user["id"]) + + @sio.on("join-channels") async def join_channel(sid, data): auth = data["auth"] if "auth" in data else None @@ -669,13 +652,6 @@ async def disconnect(sid): if sid in SESSION_POOL: user = SESSION_POOL[sid] del SESSION_POOL[sid] - - user_id = user["id"] - USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid] - - if len(USER_POOL[user_id]) == 0: - del USER_POOL[user_id] - await YDOC_MANAGER.remove_user_from_all_documents(sid) else: pass diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index f3069a093f..3f05256c70 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -344,9 +344,7 @@ async def get_current_user( # Refresh the user's last active timestamp asynchronously # to prevent blocking the request if background_tasks: - background_tasks.add_task( - Users.update_user_last_active_by_id, user.id - ) + background_tasks.add_task(Users.update_last_active_by_id, user.id) return user else: raise HTTPException( @@ -397,8 +395,7 @@ def get_current_user_by_api_key(request, api_key: str): current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.auth.type", "api_key") - Users.update_user_last_active_by_id(user.id) - + Users.update_last_active_by_id(user.id) return user diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index cc2de8e1c7..dc45daca0e 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -32,7 +32,6 @@ from open_webui.models.users import Users from open_webui.socket.main import ( get_event_call, get_event_emitter, - get_active_status_by_user_id, ) from open_webui.routers.tasks import ( generate_queries, @@ -1915,7 +1914,7 @@ async def process_chat_response( ) # Send a webhook notification if the user is not active - if not get_active_status_by_user_id(user.id): + if not Users.is_user_active(user.id): webhook_url = Users.get_user_webhook_url_by_id(user.id) if webhook_url: await post_webhook( @@ -3210,7 +3209,7 @@ async def process_chat_response( ) # Send a webhook notification if the user is not active - if not get_active_status_by_user_id(user.id): + if not Users.is_user_active(user.id): webhook_url = Users.get_user_webhook_url_by_id(user.id) if webhook_url: await post_webhook( diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index f8a924e8d0..6bd955e90c 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1329,7 +1329,10 @@ class OAuthManager: log.warning(f"OAuth callback failed, sub is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" + oauth_data = {} + oauth_data[provider] = { + "sub": sub, + } # Email extraction email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM @@ -1376,12 +1379,12 @@ class OAuthManager: log.warning(f"Error fetching GitHub email: {e}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) elif ENABLE_OAUTH_EMAIL_FALLBACK: - email = f"{provider_sub}.local" + email = f"{provider}@{sub}.local" else: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - email = email.lower() + email = email.lower() # If allowed domains are configured, check if the email domain is in the list if ( "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS @@ -1394,7 +1397,7 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) + user = Users.get_user_by_oauth_sub(provider, sub) if not user: # If the user does not exist, check if merging is enabled if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: @@ -1402,7 +1405,7 @@ class OAuthManager: user = Users.get_user_by_email(email) if user: # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) + Users.update_user_oauth_by_id(user.id, provider, sub) if user: determined_role = self.get_user_role(user, user_data) @@ -1461,7 +1464,7 @@ class OAuthManager: name=name, profile_image_url=picture_url, role=self.get_user_role(None, user_data), - oauth_sub=provider_sub, + oauth=oauth_data, ) if auth_manager_config.WEBHOOK_URL: diff --git a/backend/open_webui/utils/telemetry/metrics.py b/backend/open_webui/utils/telemetry/metrics.py index 85bd418844..d935ddaafa 100644 --- a/backend/open_webui/utils/telemetry/metrics.py +++ b/backend/open_webui/utils/telemetry/metrics.py @@ -45,7 +45,6 @@ from open_webui.env import ( OTEL_METRICS_OTLP_SPAN_EXPORTER, OTEL_METRICS_EXPORTER_OTLP_INSECURE, ) -from open_webui.socket.main import get_active_user_ids from open_webui.models.users import Users _EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds @@ -135,7 +134,7 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None: ) -> Sequence[metrics.Observation]: return [ metrics.Observation( - value=len(get_active_user_ids()), + value=Users.get_active_user_count(), ) ] diff --git a/src/lib/components/admin/Users/UserList/EditUserModal.svelte b/src/lib/components/admin/Users/UserList/EditUserModal.svelte index 9adbac0e4f..f73551219a 100644 --- a/src/lib/components/admin/Users/UserList/EditUserModal.svelte +++ b/src/lib/components/admin/Users/UserList/EditUserModal.svelte @@ -180,12 +180,17 @@ - {#if _user?.oauth_sub} + {#if _user?.oauth}