Merge pull request #19573 from open-webui/update-user-table

refac/db: update user table
This commit is contained in:
Tim Baek 2025-11-28 07:55:00 -05:00 committed by GitHub
commit fc06c16dd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 495 additions and 159 deletions

View file

@ -66,7 +66,6 @@ from open_webui.socket.main import (
periodic_usage_pool_cleanup, periodic_usage_pool_cleanup,
get_event_emitter, get_event_emitter,
get_models_in_use, get_models_in_use,
get_active_user_ids,
) )
from open_webui.routers import ( from open_webui.routers import (
audio, audio,
@ -2021,7 +2020,10 @@ async def get_current_usage(user=Depends(get_verified_user)):
This is an experimental endpoint and subject to change. This is an experimental endpoint and subject to change.
""" """
try: try:
return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()} return {
"model_ids": get_models_in_use(),
"user_count": Users.get_active_user_count(),
}
except Exception as e: except Exception as e:
log.error(f"Error getting usage statistics: {e}") log.error(f"Error getting usage statistics: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error") raise HTTPException(status_code=500, detail="Internal Server Error")

View file

@ -0,0 +1,251 @@
"""Update user table
Revision ID: b10670c03dd5
Revises: 2f1211949ecc
Create Date: 2025-11-28 04:55:31.737538
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import json
import time
# revision identifiers, used by Alembic.
revision: str = "b10670c03dd5"
down_revision: Union[str, None] = "2f1211949ecc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
"""
SQLite requires manual removal of any indexes referencing a column
before ALTER TABLE ... DROP COLUMN can succeed.
"""
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
for idx in indexes:
index_name = idx[1] # index name
# Get indexed columns
idx_info = conn.execute(
sa.text(f"PRAGMA index_info('{index_name}')")
).fetchall()
indexed_cols = [row[2] for row in idx_info] # col names
if column_name in indexed_cols:
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
def _convert_column_to_json(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
# SQLite cannot ALTER COLUMN → must recreate column
if dialect == "sqlite":
# 1. Add temporary column
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
# 2. Load old data
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for row in rows:
uid, raw = row
if raw is None:
parsed = None
else:
try:
parsed = json.loads(raw)
except Exception:
parsed = None # fallback safe behavior
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
{"val": json.dumps(parsed) if parsed else None, "id": uid},
)
# 3. Drop old TEXT column
op.drop_column(table, column)
# 4. Rename new JSON column → original name
op.alter_column(table, f"{column}_json", new_column_name=column)
else:
# PostgreSQL supports direct CAST
op.alter_column(
table,
column,
type_=sa.JSON(),
postgresql_using=f"{column}::json",
)
def _convert_column_to_text(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
if dialect == "sqlite":
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for uid, raw in rows:
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
{"val": json.dumps(raw) if raw else None, "id": uid},
)
op.drop_column(table, column)
op.alter_column(table, f"{column}_text", new_column_name=column)
else:
op.alter_column(
table,
column,
type_=sa.Text(),
postgresql_using=f"to_json({column})::text",
)
def upgrade() -> None:
op.add_column(
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
)
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
op.add_column(
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
)
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
# Convert info (TEXT/JSONField) → JSON
_convert_column_to_json("user", "info")
# Convert settings (TEXT/JSONField) → JSON
_convert_column_to_json("user", "settings")
op.create_table(
"api_key",
sa.Column("id", sa.Text(), primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
sa.Column("key", sa.Text(), unique=True, nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
).fetchall()
for uid, oauth_sub in users:
if oauth_sub:
# Example formats supported:
# provider@sub
# plain sub (stored as {"oidc": {"sub": sub}})
if "@" in oauth_sub:
provider, sub = oauth_sub.split("@", 1)
else:
provider, sub = "oidc", oauth_sub
oauth_json = json.dumps({provider: {"sub": sub}})
conn.execute(
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
{"oauth": oauth_json, "id": uid},
)
users_with_keys = conn.execute(
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
).fetchall()
now = int(time.time())
for uid, api_key in users_with_keys:
if api_key:
conn.execute(
sa.text(
"""
INSERT INTO api_key (id, user_id, key, created_at, updated_at)
VALUES (:id, :user_id, :key, :created_at, :updated_at)
"""
),
{
"id": f"key_{uid}",
"user_id": uid,
"key": api_key,
"created_at": now,
"updated_at": now,
},
)
if conn.dialect.name == "sqlite":
_drop_sqlite_indexes_for_column("user", "api_key", conn)
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("api_key")
batch_op.drop_column("oauth_sub")
def downgrade() -> None:
# --- 1. Restore old oauth_sub column ---
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
).fetchall()
for uid, oauth in users:
try:
data = json.loads(oauth)
provider = list(data.keys())[0]
sub = data[provider].get("sub")
oauth_sub = f"{provider}@{sub}"
except Exception:
oauth_sub = None
conn.execute(
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
{"oauth_sub": oauth_sub, "id": uid},
)
op.drop_column("user", "oauth")
# --- 2. Restore api_key field ---
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
# Restore values from api_key
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
for uid, key in keys:
conn.execute(
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
{"key": key, "id": uid},
)
# Drop new table
op.drop_table("api_key")
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("profile_banner_image_url")
batch_op.drop_column("timezone")
batch_op.drop_column("presence_state")
batch_op.drop_column("status_emoji")
batch_op.drop_column("status_message")
batch_op.drop_column("status_expires_at")
# Convert info (JSON) → TEXT
_convert_column_to_text("user", "info")
# Convert settings (JSON) → TEXT
_convert_column_to_text("user", "settings")

View file

@ -88,7 +88,7 @@ class AuthsTable:
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth: Optional[dict] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
@ -102,7 +102,7 @@ class AuthsTable:
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth=oauth
) )
db.commit() db.commit()

View file

@ -11,7 +11,17 @@ from open_webui.utils.misc import throttle
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Date, exists, select from sqlalchemy import (
BigInteger,
JSON,
Column,
String,
Boolean,
Text,
Date,
exists,
select,
)
from sqlalchemy import or_, case from sqlalchemy import or_, case
import datetime import datetime
@ -21,59 +31,71 @@ import datetime
#################### ####################
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True, unique=True)
name = Column(String)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
profile_image_url = Column(Text)
bio = Column(Text, nullable=True)
gender = Column(Text, nullable=True)
date_of_birth = Column(Date, nullable=True)
info = Column(JSONField, nullable=True)
settings = Column(JSONField, nullable=True)
api_key = Column(String, nullable=True, unique=True)
oauth_sub = Column(Text, unique=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class UserSettings(BaseModel): class UserSettings(BaseModel):
ui: Optional[dict] = {} ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
pass pass
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True, unique=True)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
name = Column(String)
profile_image_url = Column(Text)
profile_banner_image_url = Column(Text, nullable=True)
bio = Column(Text, nullable=True)
gender = Column(Text, nullable=True)
date_of_birth = Column(Date, nullable=True)
timezone = Column(String, nullable=True)
presence_state = Column(String, nullable=True)
status_emoji = Column(String, nullable=True)
status_message = Column(Text, nullable=True)
status_expires_at = Column(BigInteger, nullable=True)
info = Column(JSON, nullable=True)
settings = Column(JSON, nullable=True)
oauth = Column(JSON, nullable=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class UserModel(BaseModel): class UserModel(BaseModel):
id: str id: str
name: str
email: str email: str
username: Optional[str] = None username: Optional[str] = None
role: str = "pending" role: str = "pending"
name: str
profile_image_url: str profile_image_url: str
profile_banner_image_url: Optional[str] = None
bio: Optional[str] = None bio: Optional[str] = None
gender: Optional[str] = None gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None date_of_birth: Optional[datetime.date] = None
timezone: Optional[str] = None
presence_state: Optional[str] = None
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
info: Optional[dict] = None info: Optional[dict] = None
settings: Optional[UserSettings] = None settings: Optional[UserSettings] = None
api_key: Optional[str] = None oauth: Optional[dict] = None
oauth_sub: Optional[str] = None
last_active_at: int # timestamp in epoch last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@ -82,6 +104,32 @@ class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ApiKey(Base):
__tablename__ = "api_key"
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
key = Column(Text, unique=True, nullable=False)
data = Column(JSON, nullable=True)
expires_at = Column(BigInteger, nullable=True)
last_used_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class ApiKeyModel(BaseModel):
id: str
user_id: str
key: str
data: Optional[dict] = None
expires_at: Optional[int] = None
last_used_at: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
#################### ####################
@ -128,7 +176,7 @@ class UserIdNameResponse(BaseModel):
class UserIdNameStatusResponse(BaseModel): class UserIdNameStatusResponse(BaseModel):
id: str id: str
name: str name: str
is_active: bool is_active: bool = False
class UserInfoListResponse(BaseModel): class UserInfoListResponse(BaseModel):
@ -177,20 +225,20 @@ class UsersTable:
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth: Optional[dict] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
"name": name,
"email": email, "email": email,
"name": name,
"role": role, "role": role,
"profile_image_url": profile_image_url, "profile_image_url": profile_image_url,
"last_active_at": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"oauth_sub": oauth_sub, "oauth": oauth,
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
@ -213,8 +261,13 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first() user = (
return UserModel.model_validate(user) db.query(User)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key)
.first()
)
return UserModel.model_validate(user) if user else None
except Exception: except Exception:
return None return None
@ -226,11 +279,15 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first() user = (
return UserModel.model_validate(user) db.query(User)
.filter(User.oauth.contains({provider: {"sub": sub}}))
.first()
)
return UserModel.model_validate(user) if user else None
except Exception: except Exception:
return None return None
@ -432,7 +489,7 @@ class UsersTable:
return None return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
@ -445,16 +502,35 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_oauth_sub_by_id( def update_user_oauth_by_id(
self, id: str, oauth_sub: str self, id: str, provider: str, sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
{
"google": { "sub": "123" },
"github": { "sub": "abc" }
}
"""
try: try:
with get_db() as db: with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) user = db.query(User).filter_by(id=id).first()
if not user:
return None
# Load existing oauth JSON or create empty
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {"sub": sub}
# Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth})
db.commit() db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
@ -508,23 +584,45 @@ class UsersTable:
except Exception: except Exception:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except Exception:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(id=id).first() api_key = db.query(ApiKey).filter_by(user_id=id).first()
return user.api_key return api_key.key if api_key else None
except Exception: except Exception:
return None return None
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
now = int(time.time())
new_api_key = ApiKey(
id=f"key_{id}",
user_id=id,
key=api_key,
created_at=now,
updated_at=now,
)
db.add(new_api_key)
db.commit()
return True
except Exception:
return False
def delete_user_api_key_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
return True
except Exception:
return False
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
with get_db() as db: with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
@ -538,5 +636,23 @@ class UsersTable:
else: else:
return None return None
def get_active_user_count(self) -> int:
with get_db() as db:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
count = (
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
)
return count
def is_user_active(self, user_id: str) -> bool:
with get_db() as db:
user = db.query(User).filter_by(id=user_id).first()
if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
return user.last_active_at >= three_minutes_ago
return False
Users = UsersTable() Users = UsersTable()

View file

@ -1133,8 +1133,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
# delete api key # delete api key
@router.delete("/api_key", response_model=bool) @router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)): async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None) return Users.delete_user_api_key_by_id(user.id)
return success
# get api key # get api key

View file

@ -10,7 +10,6 @@ from pydantic import BaseModel
from open_webui.socket.main import ( from open_webui.socket.main import (
sio, sio,
get_user_ids_from_room, get_user_ids_from_room,
get_active_status_by_user_id,
) )
from open_webui.models.users import ( from open_webui.models.users import (
UserIdNameResponse, UserIdNameResponse,
@ -99,10 +98,7 @@ async def get_channels(user=Depends(get_verified_user)):
] ]
users = [ users = [
UserIdNameStatusResponse( UserIdNameStatusResponse(
**{ **{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
**user.model_dump(),
"is_active": get_active_status_by_user_id(user.id),
}
) )
for user in Users.get_users_by_user_ids(user_ids) for user in Users.get_users_by_user_ids(user_ids)
] ]
@ -284,7 +280,7 @@ async def get_channel_members_by_id(
return { return {
"users": [ "users": [
UserModelResponse( UserModelResponse(
**user.model_dump(), is_active=get_active_status_by_user_id(user.id) **user.model_dump(), is_active=Users.is_user_active(user.id)
) )
for user in users for user in users
], ],
@ -316,7 +312,7 @@ async def get_channel_members_by_id(
return { return {
"users": [ "users": [
UserModelResponse( UserModelResponse(
**user.model_dump(), is_active=get_active_status_by_user_id(user.id) **user.model_dump(), is_active=Users.is_user_active(user.id)
) )
for user in users for user in users
], ],

View file

@ -26,12 +26,6 @@ from open_webui.models.users import (
UserUpdateForm, UserUpdateForm,
) )
from open_webui.socket.main import (
get_active_status_by_user_id,
get_active_user_ids,
get_user_active_status,
)
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
@ -51,23 +45,6 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################
# GetActiveUsers
############################
@router.get("/active")
async def get_active_users(
user=Depends(get_verified_user),
):
"""
Get a list of active users.
"""
return {
"user_ids": get_active_user_ids(),
}
############################ ############################
# GetUsers # GetUsers
############################ ############################
@ -364,7 +341,7 @@ async def update_user_info_by_session_user(
class UserActiveResponse(BaseModel): class UserActiveResponse(BaseModel):
name: str name: str
profile_image_url: Optional[str] = None profile_image_url: Optional[str] = None
active: Optional[bool] = None is_active: bool
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@ -390,7 +367,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
**{ **{
"id": user.id, "id": user.id,
"name": user.name, "name": user.name,
"active": get_active_status_by_user_id(user_id), "is_active": Users.is_user_active(user_id),
} }
) )
else: else:
@ -457,7 +434,7 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
@router.get("/{user_id}/active", response_model=dict) @router.get("/{user_id}/active", response_model=dict)
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
return { return {
"active": get_user_active_status(user_id), "active": Users.is_user_active(user_id),
} }

View file

@ -132,12 +132,6 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER, redis_cluster=WEBSOCKET_REDIS_CLUSTER,
) )
USER_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:user_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
USAGE_POOL = RedisDict( USAGE_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:usage_pool", f"{REDIS_KEY_PREFIX}:usage_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
@ -159,7 +153,6 @@ else:
MODELS = {} MODELS = {}
SESSION_POOL = {} SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {} USAGE_POOL = {}
aquire_func = release_func = renew_func = lambda: True aquire_func = release_func = renew_func = lambda: True
@ -235,16 +228,6 @@ def get_models_in_use():
return models_in_use return models_in_use
def get_active_user_ids():
"""Get the list of active user IDs."""
return list(USER_POOL.keys())
def get_user_active_status(user_id):
"""Check if a user is currently active."""
return user_id in USER_POOL
def get_user_id_from_session_pool(sid): def get_user_id_from_session_pool(sid):
user = SESSION_POOL.get(sid) user = SESSION_POOL.get(sid)
if user: if user:
@ -270,12 +253,6 @@ def get_user_ids_from_room(room):
return active_user_ids return active_user_ids
def get_active_status_by_user_id(user_id):
if user_id in USER_POOL:
return True
return False
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
if sid in SESSION_POOL: if sid in SESSION_POOL:
@ -303,11 +280,6 @@ async def connect(sid, environ, auth):
SESSION_POOL[sid] = user.model_dump( SESSION_POOL[sid] = user.model_dump(
exclude=["date_of_birth", "bio", "gender"] exclude=["date_of_birth", "bio", "gender"]
) )
if user.id in USER_POOL:
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
else:
USER_POOL[user.id] = [sid]
await sio.enter_room(sid, f"user:{user.id}") await sio.enter_room(sid, f"user:{user.id}")
@ -326,11 +298,15 @@ async def user_join(sid, data):
if not user: if not user:
return return
SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"]) SESSION_POOL[sid] = user.model_dump(
if user.id in USER_POOL: exclude=[
USER_POOL[user.id] = USER_POOL[user.id] + [sid] "profile_image_url",
else: "profile_banner_image_url",
USER_POOL[user.id] = [sid] "date_of_birth",
"bio",
"gender",
]
)
await sio.enter_room(sid, f"user:{user.id}") await sio.enter_room(sid, f"user:{user.id}")
# Join all the channels # Join all the channels
@ -341,6 +317,13 @@ async def user_join(sid, data):
return {"id": user.id, "name": user.name} return {"id": user.id, "name": user.name}
@sio.on("heartbeat")
async def heartbeat(sid, data):
user = SESSION_POOL.get(sid)
if user:
Users.update_last_active_by_id(user["id"])
@sio.on("join-channels") @sio.on("join-channels")
async def join_channel(sid, data): async def join_channel(sid, data):
auth = data["auth"] if "auth" in data else None auth = data["auth"] if "auth" in data else None
@ -669,13 +652,6 @@ async def disconnect(sid):
if sid in SESSION_POOL: if sid in SESSION_POOL:
user = SESSION_POOL[sid] user = SESSION_POOL[sid]
del SESSION_POOL[sid] del SESSION_POOL[sid]
user_id = user["id"]
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await YDOC_MANAGER.remove_user_from_all_documents(sid) await YDOC_MANAGER.remove_user_from_all_documents(sid)
else: else:
pass pass

View file

@ -344,9 +344,7 @@ async def get_current_user(
# Refresh the user's last active timestamp asynchronously # Refresh the user's last active timestamp asynchronously
# to prevent blocking the request # to prevent blocking the request
if background_tasks: if background_tasks:
background_tasks.add_task( background_tasks.add_task(Users.update_last_active_by_id, user.id)
Users.update_user_last_active_by_id, user.id
)
return user return user
else: else:
raise HTTPException( raise HTTPException(
@ -397,8 +395,7 @@ def get_current_user_by_api_key(request, api_key: str):
current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key") current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id) Users.update_last_active_by_id(user.id)
return user return user

View file

@ -32,7 +32,6 @@ from open_webui.models.users import Users
from open_webui.socket.main import ( from open_webui.socket.main import (
get_event_call, get_event_call,
get_event_emitter, get_event_emitter,
get_active_status_by_user_id,
) )
from open_webui.routers.tasks import ( from open_webui.routers.tasks import (
generate_queries, generate_queries,
@ -1915,7 +1914,7 @@ async def process_chat_response(
) )
# Send a webhook notification if the user is not active # Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id): if not Users.is_user_active(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url: if webhook_url:
await post_webhook( await post_webhook(
@ -3210,7 +3209,7 @@ async def process_chat_response(
) )
# Send a webhook notification if the user is not active # Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id): if not Users.is_user_active(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url: if webhook_url:
await post_webhook( await post_webhook(

View file

@ -1329,7 +1329,10 @@ class OAuthManager:
log.warning(f"OAuth callback failed, sub is missing: {user_data}") log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}" oauth_data = {}
oauth_data[provider] = {
"sub": sub,
}
# Email extraction # Email extraction
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
@ -1376,12 +1379,12 @@ class OAuthManager:
log.warning(f"Error fetching GitHub email: {e}") log.warning(f"Error fetching GitHub email: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
elif ENABLE_OAUTH_EMAIL_FALLBACK: elif ENABLE_OAUTH_EMAIL_FALLBACK:
email = f"{provider_sub}.local" email = f"{provider}@{sub}.local"
else: else:
log.warning(f"OAuth callback failed, email is missing: {user_data}") log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
email = email.lower()
email = email.lower()
# If allowed domains are configured, check if the email domain is in the list # If allowed domains are configured, check if the email domain is in the list
if ( if (
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
@ -1394,7 +1397,7 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists # Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub) user = Users.get_user_by_oauth_sub(provider, sub)
if not user: if not user:
# If the user does not exist, check if merging is enabled # If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
@ -1402,7 +1405,7 @@ class OAuthManager:
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
if user: if user:
# Update the user with the new oauth sub # Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub) Users.update_user_oauth_by_id(user.id, provider, sub)
if user: if user:
determined_role = self.get_user_role(user, user_data) determined_role = self.get_user_role(user, user_data)
@ -1461,7 +1464,7 @@ class OAuthManager:
name=name, name=name,
profile_image_url=picture_url, profile_image_url=picture_url,
role=self.get_user_role(None, user_data), role=self.get_user_role(None, user_data),
oauth_sub=provider_sub, oauth=oauth_data,
) )
if auth_manager_config.WEBHOOK_URL: if auth_manager_config.WEBHOOK_URL:

View file

@ -45,7 +45,6 @@ from open_webui.env import (
OTEL_METRICS_OTLP_SPAN_EXPORTER, OTEL_METRICS_OTLP_SPAN_EXPORTER,
OTEL_METRICS_EXPORTER_OTLP_INSECURE, OTEL_METRICS_EXPORTER_OTLP_INSECURE,
) )
from open_webui.socket.main import get_active_user_ids
from open_webui.models.users import Users from open_webui.models.users import Users
_EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds _EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds
@ -135,7 +134,7 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None:
) -> Sequence[metrics.Observation]: ) -> Sequence[metrics.Observation]:
return [ return [
metrics.Observation( metrics.Observation(
value=len(get_active_user_ids()), value=Users.get_active_user_count(),
) )
] ]

View file

@ -180,12 +180,17 @@
</div> </div>
</div> </div>
{#if _user?.oauth_sub} {#if _user?.oauth}
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('OAuth ID')}</div> <div class=" mb-1 text-xs text-gray-500">{$i18n.t('OAuth ID')}</div>
<div class="flex-1 text-sm break-all mb-1"> <div class="flex-1 text-sm break-all mb-1 flex flex-col space-y-1">
{_user.oauth_sub ?? ''} {#each Object.keys(_user.oauth) as key}
<div>
<span class="text-gray-500">{key}</span>
<span class="">{_user.oauth[key]?.sub}</span>
</div>
{/each}
</div> </div>
</div> </div>
{/if} {/if}

View file

@ -23,7 +23,7 @@
</div> </div>
<div class=" flex items-center gap-2"> <div class=" flex items-center gap-2">
{#if user?.active} {#if user?.is_active}
<div> <div>
<span class="relative flex size-2"> <span class="relative flex size-2">
<span <span

View file

@ -222,7 +222,7 @@
</DropdownMenu.Item> </DropdownMenu.Item>
{#if showActiveUsers && usage} {#if showActiveUsers && usage}
{#if usage?.user_ids?.length > 0} {#if usage?.user_count}
<hr class=" border-gray-50 dark:border-gray-800 my-1 p-0" /> <hr class=" border-gray-50 dark:border-gray-800 my-1 p-0" />
<Tooltip <Tooltip
@ -250,7 +250,7 @@
{$i18n.t('Active Users')}: {$i18n.t('Active Users')}:
</span> </span>
<span class=" font-semibold"> <span class=" font-semibold">
{usage?.user_ids?.length} {usage?.user_count}
</span> </span>
</div> </div>
</div> </div>

View file

@ -90,6 +90,8 @@
let showRefresh = false; let showRefresh = false;
let heartbeatInterval = null;
const BREAKPOINT = 768; const BREAKPOINT = 768;
const setupSocket = async (enableWebsocket) => { const setupSocket = async (enableWebsocket) => {
@ -126,6 +128,14 @@
} }
} }
// Send heartbeat every 30 seconds
heartbeatInterval = setInterval(() => {
if (_socket.connected) {
console.log('Sending heartbeat');
_socket.emit('heartbeat', {});
}
}, 30000);
if (deploymentId !== null) { if (deploymentId !== null) {
WEBUI_DEPLOYMENT_ID.set(deploymentId); WEBUI_DEPLOYMENT_ID.set(deploymentId);
} }
@ -154,6 +164,12 @@
_socket.on('disconnect', (reason, details) => { _socket.on('disconnect', (reason, details) => {
console.log(`Socket ${_socket.id} disconnected due to ${reason}`); console.log(`Socket ${_socket.id} disconnected due to ${reason}`);
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
heartbeatInterval = null;
}
if (details) { if (details) {
console.log('Additional details:', details); console.log('Additional details:', details);
} }