mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
Merge pull request #19573 from open-webui/update-user-table
refac/db: update user table
This commit is contained in:
commit
fc06c16dd4
16 changed files with 495 additions and 159 deletions
|
|
@ -66,7 +66,6 @@ from open_webui.socket.main import (
|
|||
periodic_usage_pool_cleanup,
|
||||
get_event_emitter,
|
||||
get_models_in_use,
|
||||
get_active_user_ids,
|
||||
)
|
||||
from open_webui.routers import (
|
||||
audio,
|
||||
|
|
@ -2021,7 +2020,10 @@ async def get_current_usage(user=Depends(get_verified_user)):
|
|||
This is an experimental endpoint and subject to change.
|
||||
"""
|
||||
try:
|
||||
return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()}
|
||||
return {
|
||||
"model_ids": get_models_in_use(),
|
||||
"user_count": Users.get_active_user_count(),
|
||||
}
|
||||
except Exception as e:
|
||||
log.error(f"Error getting usage statistics: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal Server Error")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,251 @@
|
|||
"""Update user table
|
||||
|
||||
Revision ID: b10670c03dd5
|
||||
Revises: 2f1211949ecc
|
||||
Create Date: 2025-11-28 04:55:31.737538
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
import open_webui.internal.db
|
||||
import json
|
||||
import time
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b10670c03dd5"
|
||||
down_revision: Union[str, None] = "2f1211949ecc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
|
||||
"""
|
||||
SQLite requires manual removal of any indexes referencing a column
|
||||
before ALTER TABLE ... DROP COLUMN can succeed.
|
||||
"""
|
||||
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
|
||||
|
||||
for idx in indexes:
|
||||
index_name = idx[1] # index name
|
||||
# Get indexed columns
|
||||
idx_info = conn.execute(
|
||||
sa.text(f"PRAGMA index_info('{index_name}')")
|
||||
).fetchall()
|
||||
|
||||
indexed_cols = [row[2] for row in idx_info] # col names
|
||||
if column_name in indexed_cols:
|
||||
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
|
||||
|
||||
|
||||
def _convert_column_to_json(table: str, column: str):
|
||||
conn = op.get_bind()
|
||||
dialect = conn.dialect.name
|
||||
|
||||
# SQLite cannot ALTER COLUMN → must recreate column
|
||||
if dialect == "sqlite":
|
||||
# 1. Add temporary column
|
||||
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
|
||||
|
||||
# 2. Load old data
|
||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||
|
||||
for row in rows:
|
||||
uid, raw = row
|
||||
if raw is None:
|
||||
parsed = None
|
||||
else:
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
parsed = None # fallback safe behavior
|
||||
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
|
||||
{"val": json.dumps(parsed) if parsed else None, "id": uid},
|
||||
)
|
||||
|
||||
# 3. Drop old TEXT column
|
||||
op.drop_column(table, column)
|
||||
|
||||
# 4. Rename new JSON column → original name
|
||||
op.alter_column(table, f"{column}_json", new_column_name=column)
|
||||
|
||||
else:
|
||||
# PostgreSQL supports direct CAST
|
||||
op.alter_column(
|
||||
table,
|
||||
column,
|
||||
type_=sa.JSON(),
|
||||
postgresql_using=f"{column}::json",
|
||||
)
|
||||
|
||||
|
||||
def _convert_column_to_text(table: str, column: str):
|
||||
conn = op.get_bind()
|
||||
dialect = conn.dialect.name
|
||||
|
||||
if dialect == "sqlite":
|
||||
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
|
||||
|
||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||
|
||||
for uid, raw in rows:
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
|
||||
{"val": json.dumps(raw) if raw else None, "id": uid},
|
||||
)
|
||||
|
||||
op.drop_column(table, column)
|
||||
op.alter_column(table, f"{column}_text", new_column_name=column)
|
||||
|
||||
else:
|
||||
op.alter_column(
|
||||
table,
|
||||
column,
|
||||
type_=sa.Text(),
|
||||
postgresql_using=f"to_json({column})::text",
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
|
||||
|
||||
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
|
||||
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
|
||||
|
||||
# Convert info (TEXT/JSONField) → JSON
|
||||
_convert_column_to_json("user", "info")
|
||||
# Convert settings (TEXT/JSONField) → JSON
|
||||
_convert_column_to_json("user", "settings")
|
||||
|
||||
op.create_table(
|
||||
"api_key",
|
||||
sa.Column("id", sa.Text(), primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
|
||||
sa.Column("key", sa.Text(), unique=True, nullable=False),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("expires_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
|
||||
).fetchall()
|
||||
|
||||
for uid, oauth_sub in users:
|
||||
if oauth_sub:
|
||||
# Example formats supported:
|
||||
# provider@sub
|
||||
# plain sub (stored as {"oidc": {"sub": sub}})
|
||||
if "@" in oauth_sub:
|
||||
provider, sub = oauth_sub.split("@", 1)
|
||||
else:
|
||||
provider, sub = "oidc", oauth_sub
|
||||
|
||||
oauth_json = json.dumps({provider: {"sub": sub}})
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
|
||||
{"oauth": oauth_json, "id": uid},
|
||||
)
|
||||
|
||||
users_with_keys = conn.execute(
|
||||
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
|
||||
).fetchall()
|
||||
now = int(time.time())
|
||||
|
||||
for uid, api_key in users_with_keys:
|
||||
if api_key:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO api_key (id, user_id, key, created_at, updated_at)
|
||||
VALUES (:id, :user_id, :key, :created_at, :updated_at)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": f"key_{uid}",
|
||||
"user_id": uid,
|
||||
"key": api_key,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
if conn.dialect.name == "sqlite":
|
||||
_drop_sqlite_indexes_for_column("user", "api_key", conn)
|
||||
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
batch_op.drop_column("api_key")
|
||||
batch_op.drop_column("oauth_sub")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# --- 1. Restore old oauth_sub column ---
|
||||
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
|
||||
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
|
||||
).fetchall()
|
||||
|
||||
for uid, oauth in users:
|
||||
try:
|
||||
data = json.loads(oauth)
|
||||
provider = list(data.keys())[0]
|
||||
sub = data[provider].get("sub")
|
||||
oauth_sub = f"{provider}@{sub}"
|
||||
except Exception:
|
||||
oauth_sub = None
|
||||
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
|
||||
{"oauth_sub": oauth_sub, "id": uid},
|
||||
)
|
||||
|
||||
op.drop_column("user", "oauth")
|
||||
|
||||
# --- 2. Restore api_key field ---
|
||||
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
|
||||
|
||||
# Restore values from api_key
|
||||
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
|
||||
for uid, key in keys:
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
|
||||
{"key": key, "id": uid},
|
||||
)
|
||||
|
||||
# Drop new table
|
||||
op.drop_table("api_key")
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
batch_op.drop_column("profile_banner_image_url")
|
||||
batch_op.drop_column("timezone")
|
||||
|
||||
batch_op.drop_column("presence_state")
|
||||
batch_op.drop_column("status_emoji")
|
||||
batch_op.drop_column("status_message")
|
||||
batch_op.drop_column("status_expires_at")
|
||||
|
||||
# Convert info (JSON) → TEXT
|
||||
_convert_column_to_text("user", "info")
|
||||
# Convert settings (JSON) → TEXT
|
||||
_convert_column_to_text("user", "settings")
|
||||
|
|
@ -88,7 +88,7 @@ class AuthsTable:
|
|||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
oauth: Optional[dict] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
log.info("insert_new_auth")
|
||||
|
|
@ -102,7 +102,7 @@ class AuthsTable:
|
|||
db.add(result)
|
||||
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth_sub
|
||||
id, name, email, profile_image_url, role, oauth=oauth
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,17 @@ from open_webui.utils.misc import throttle
|
|||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, Date, exists, select
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
Column,
|
||||
String,
|
||||
Boolean,
|
||||
Text,
|
||||
Date,
|
||||
exists,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy import or_, case
|
||||
|
||||
import datetime
|
||||
|
|
@ -21,59 +31,71 @@ import datetime
|
|||
####################
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
name = Column(String)
|
||||
|
||||
email = Column(String)
|
||||
username = Column(String(50), nullable=True)
|
||||
|
||||
role = Column(String)
|
||||
profile_image_url = Column(Text)
|
||||
|
||||
bio = Column(Text, nullable=True)
|
||||
gender = Column(Text, nullable=True)
|
||||
date_of_birth = Column(Date, nullable=True)
|
||||
|
||||
info = Column(JSONField, nullable=True)
|
||||
settings = Column(JSONField, nullable=True)
|
||||
|
||||
api_key = Column(String, nullable=True, unique=True)
|
||||
oauth_sub = Column(Text, unique=True)
|
||||
|
||||
last_active_at = Column(BigInteger)
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
username = Column(String(50), nullable=True)
|
||||
role = Column(String)
|
||||
|
||||
name = Column(String)
|
||||
|
||||
profile_image_url = Column(Text)
|
||||
profile_banner_image_url = Column(Text, nullable=True)
|
||||
|
||||
bio = Column(Text, nullable=True)
|
||||
gender = Column(Text, nullable=True)
|
||||
date_of_birth = Column(Date, nullable=True)
|
||||
timezone = Column(String, nullable=True)
|
||||
|
||||
presence_state = Column(String, nullable=True)
|
||||
status_emoji = Column(String, nullable=True)
|
||||
status_message = Column(Text, nullable=True)
|
||||
status_expires_at = Column(BigInteger, nullable=True)
|
||||
|
||||
info = Column(JSON, nullable=True)
|
||||
settings = Column(JSON, nullable=True)
|
||||
|
||||
oauth = Column(JSON, nullable=True)
|
||||
|
||||
last_active_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
email: str
|
||||
username: Optional[str] = None
|
||||
|
||||
role: str = "pending"
|
||||
|
||||
name: str
|
||||
|
||||
profile_image_url: str
|
||||
profile_banner_image_url: Optional[str] = None
|
||||
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
timezone: Optional[str] = None
|
||||
|
||||
presence_state: Optional[str] = None
|
||||
status_emoji: Optional[str] = None
|
||||
status_message: Optional[str] = None
|
||||
status_expires_at: Optional[int] = None
|
||||
|
||||
info: Optional[dict] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
|
||||
api_key: Optional[str] = None
|
||||
oauth_sub: Optional[str] = None
|
||||
oauth: Optional[dict] = None
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
|
@ -82,6 +104,32 @@ class UserModel(BaseModel):
|
|||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_key"
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
key = Column(Text, unique=True, nullable=False)
|
||||
data = Column(JSON, nullable=True)
|
||||
expires_at = Column(BigInteger, nullable=True)
|
||||
last_used_at = Column(BigInteger, nullable=True)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
|
||||
class ApiKeyModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
key: str
|
||||
data: Optional[dict] = None
|
||||
expires_at: Optional[int] = None
|
||||
last_used_at: Optional[int] = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
|
@ -128,7 +176,7 @@ class UserIdNameResponse(BaseModel):
|
|||
class UserIdNameStatusResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
is_active: bool
|
||||
is_active: bool = False
|
||||
|
||||
|
||||
class UserInfoListResponse(BaseModel):
|
||||
|
|
@ -177,20 +225,20 @@ class UsersTable:
|
|||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
oauth: Optional[dict] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
user = UserModel(
|
||||
**{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"role": role,
|
||||
"profile_image_url": profile_image_url,
|
||||
"last_active_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"oauth_sub": oauth_sub,
|
||||
"oauth": oauth,
|
||||
}
|
||||
)
|
||||
result = User(**user.model_dump())
|
||||
|
|
@ -213,8 +261,13 @@ class UsersTable:
|
|||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(api_key=api_key).first()
|
||||
return UserModel.model_validate(user)
|
||||
user = (
|
||||
db.query(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.filter(ApiKey.key == api_key)
|
||||
.first()
|
||||
)
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -226,11 +279,15 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
||||
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(oauth_sub=sub).first()
|
||||
return UserModel.model_validate(user)
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.oauth.contains({provider: {"sub": sub}}))
|
||||
.first()
|
||||
)
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -432,7 +489,7 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(User).filter_by(id=id).update(
|
||||
|
|
@ -445,16 +502,35 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_oauth_sub_by_id(
|
||||
self, id: str, oauth_sub: str
|
||||
def update_user_oauth_by_id(
|
||||
self, id: str, provider: str, sub: str
|
||||
) -> Optional[UserModel]:
|
||||
"""
|
||||
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
||||
Example resulting structure:
|
||||
{
|
||||
"google": { "sub": "123" },
|
||||
"github": { "sub": "abc" }
|
||||
}
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Load existing oauth JSON or create empty
|
||||
oauth = user.oauth or {}
|
||||
|
||||
# Update or insert provider entry
|
||||
oauth[provider] = {"sub": sub}
|
||||
|
||||
# Persist updated JSON
|
||||
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
||||
db.commit()
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -508,23 +584,45 @@ class UsersTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
|
||||
db.commit()
|
||||
return True if result == 1 else False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return user.api_key
|
||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
||||
return api_key.key if api_key else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
db.commit()
|
||||
|
||||
now = int(time.time())
|
||||
new_api_key = ApiKey(
|
||||
id=f"key_{id}",
|
||||
user_id=id,
|
||||
key=api_key,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(new_api_key)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_user_api_key_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
|
||||
with get_db() as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
|
|
@ -538,5 +636,23 @@ class UsersTable:
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_active_user_count(self) -> int:
|
||||
with get_db() as db:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
count = (
|
||||
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||
)
|
||||
return count
|
||||
|
||||
def is_user_active(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(id=user_id).first()
|
||||
if user and user.last_active_at:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
return user.last_active_at >= three_minutes_ago
|
||||
return False
|
||||
|
||||
|
||||
Users = UsersTable()
|
||||
|
|
|
|||
|
|
@ -1133,8 +1133,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
|||
# delete api key
|
||||
@router.delete("/api_key", response_model=bool)
|
||||
async def delete_api_key(user=Depends(get_current_user)):
|
||||
success = Users.update_user_api_key_by_id(user.id, None)
|
||||
return success
|
||||
return Users.delete_user_api_key_by_id(user.id)
|
||||
|
||||
|
||||
# get api key
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from pydantic import BaseModel
|
|||
from open_webui.socket.main import (
|
||||
sio,
|
||||
get_user_ids_from_room,
|
||||
get_active_status_by_user_id,
|
||||
)
|
||||
from open_webui.models.users import (
|
||||
UserIdNameResponse,
|
||||
|
|
@ -99,10 +98,7 @@ async def get_channels(user=Depends(get_verified_user)):
|
|||
]
|
||||
users = [
|
||||
UserIdNameStatusResponse(
|
||||
**{
|
||||
**user.model_dump(),
|
||||
"is_active": get_active_status_by_user_id(user.id),
|
||||
}
|
||||
**{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
|
||||
)
|
||||
for user in Users.get_users_by_user_ids(user_ids)
|
||||
]
|
||||
|
|
@ -284,7 +280,7 @@ async def get_channel_members_by_id(
|
|||
return {
|
||||
"users": [
|
||||
UserModelResponse(
|
||||
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id)
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
|
|
@ -316,7 +312,7 @@ async def get_channel_members_by_id(
|
|||
return {
|
||||
"users": [
|
||||
UserModelResponse(
|
||||
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id)
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
|
|
|
|||
|
|
@ -26,12 +26,6 @@ from open_webui.models.users import (
|
|||
UserUpdateForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.socket.main import (
|
||||
get_active_status_by_user_id,
|
||||
get_active_user_ids,
|
||||
get_user_active_status,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
|
||||
|
||||
|
|
@ -51,23 +45,6 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# GetActiveUsers
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/active")
|
||||
async def get_active_users(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
"""
|
||||
Get a list of active users.
|
||||
"""
|
||||
return {
|
||||
"user_ids": get_active_user_ids(),
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# GetUsers
|
||||
############################
|
||||
|
|
@ -364,7 +341,7 @@ async def update_user_info_by_session_user(
|
|||
class UserActiveResponse(BaseModel):
|
||||
name: str
|
||||
profile_image_url: Optional[str] = None
|
||||
active: Optional[bool] = None
|
||||
is_active: bool
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
|
|
@ -390,7 +367,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
|||
**{
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"active": get_active_status_by_user_id(user_id),
|
||||
"is_active": Users.is_user_active(user_id),
|
||||
}
|
||||
)
|
||||
else:
|
||||
|
|
@ -457,7 +434,7 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
|
|||
@router.get("/{user_id}/active", response_model=dict)
|
||||
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
return {
|
||||
"active": get_user_active_status(user_id),
|
||||
"active": Users.is_user_active(user_id),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -132,12 +132,6 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USER_POOL = RedisDict(
|
||||
f"{REDIS_KEY_PREFIX}:user_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USAGE_POOL = RedisDict(
|
||||
f"{REDIS_KEY_PREFIX}:usage_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
|
|
@ -159,7 +153,6 @@ else:
|
|||
MODELS = {}
|
||||
|
||||
SESSION_POOL = {}
|
||||
USER_POOL = {}
|
||||
USAGE_POOL = {}
|
||||
|
||||
aquire_func = release_func = renew_func = lambda: True
|
||||
|
|
@ -235,16 +228,6 @@ def get_models_in_use():
|
|||
return models_in_use
|
||||
|
||||
|
||||
def get_active_user_ids():
|
||||
"""Get the list of active user IDs."""
|
||||
return list(USER_POOL.keys())
|
||||
|
||||
|
||||
def get_user_active_status(user_id):
|
||||
"""Check if a user is currently active."""
|
||||
return user_id in USER_POOL
|
||||
|
||||
|
||||
def get_user_id_from_session_pool(sid):
|
||||
user = SESSION_POOL.get(sid)
|
||||
if user:
|
||||
|
|
@ -270,12 +253,6 @@ def get_user_ids_from_room(room):
|
|||
return active_user_ids
|
||||
|
||||
|
||||
def get_active_status_by_user_id(user_id):
|
||||
if user_id in USER_POOL:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@sio.on("usage")
|
||||
async def usage(sid, data):
|
||||
if sid in SESSION_POOL:
|
||||
|
|
@ -303,11 +280,6 @@ async def connect(sid, environ, auth):
|
|||
SESSION_POOL[sid] = user.model_dump(
|
||||
exclude=["date_of_birth", "bio", "gender"]
|
||||
)
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
|
||||
await sio.enter_room(sid, f"user:{user.id}")
|
||||
|
||||
|
||||
|
|
@ -326,11 +298,15 @@ async def user_join(sid, data):
|
|||
if not user:
|
||||
return
|
||||
|
||||
SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"])
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
SESSION_POOL[sid] = user.model_dump(
|
||||
exclude=[
|
||||
"profile_image_url",
|
||||
"profile_banner_image_url",
|
||||
"date_of_birth",
|
||||
"bio",
|
||||
"gender",
|
||||
]
|
||||
)
|
||||
|
||||
await sio.enter_room(sid, f"user:{user.id}")
|
||||
# Join all the channels
|
||||
|
|
@ -341,6 +317,13 @@ async def user_join(sid, data):
|
|||
return {"id": user.id, "name": user.name}
|
||||
|
||||
|
||||
@sio.on("heartbeat")
|
||||
async def heartbeat(sid, data):
|
||||
user = SESSION_POOL.get(sid)
|
||||
if user:
|
||||
Users.update_last_active_by_id(user["id"])
|
||||
|
||||
|
||||
@sio.on("join-channels")
|
||||
async def join_channel(sid, data):
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
|
|
@ -669,13 +652,6 @@ async def disconnect(sid):
|
|||
if sid in SESSION_POOL:
|
||||
user = SESSION_POOL[sid]
|
||||
del SESSION_POOL[sid]
|
||||
|
||||
user_id = user["id"]
|
||||
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
|
||||
|
||||
if len(USER_POOL[user_id]) == 0:
|
||||
del USER_POOL[user_id]
|
||||
|
||||
await YDOC_MANAGER.remove_user_from_all_documents(sid)
|
||||
else:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -344,9 +344,7 @@ async def get_current_user(
|
|||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
Users.update_user_last_active_by_id, user.id
|
||||
)
|
||||
background_tasks.add_task(Users.update_last_active_by_id, user.id)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -397,8 +395,7 @@ def get_current_user_by_api_key(request, api_key: str):
|
|||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
|
||||
Users.update_last_active_by_id(user.id)
|
||||
return user
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from open_webui.models.users import Users
|
|||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
get_active_status_by_user_id,
|
||||
)
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
|
|
@ -1915,7 +1914,7 @@ async def process_chat_response(
|
|||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
await post_webhook(
|
||||
|
|
@ -3210,7 +3209,7 @@ async def process_chat_response(
|
|||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
await post_webhook(
|
||||
|
|
|
|||
|
|
@ -1329,7 +1329,10 @@ class OAuthManager:
|
|||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
oauth_data = {}
|
||||
oauth_data[provider] = {
|
||||
"sub": sub,
|
||||
}
|
||||
|
||||
# Email extraction
|
||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||
|
|
@ -1376,12 +1379,12 @@ class OAuthManager:
|
|||
log.warning(f"Error fetching GitHub email: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
elif ENABLE_OAUTH_EMAIL_FALLBACK:
|
||||
email = f"{provider_sub}.local"
|
||||
email = f"{provider}@{sub}.local"
|
||||
else:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
email = email.lower()
|
||||
|
||||
email = email.lower()
|
||||
# If allowed domains are configured, check if the email domain is in the list
|
||||
if (
|
||||
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||
|
|
@ -1394,7 +1397,7 @@ class OAuthManager:
|
|||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||
user = Users.get_user_by_oauth_sub(provider, sub)
|
||||
if not user:
|
||||
# If the user does not exist, check if merging is enabled
|
||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||
|
|
@ -1402,7 +1405,7 @@ class OAuthManager:
|
|||
user = Users.get_user_by_email(email)
|
||||
if user:
|
||||
# Update the user with the new oauth sub
|
||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||
Users.update_user_oauth_by_id(user.id, provider, sub)
|
||||
|
||||
if user:
|
||||
determined_role = self.get_user_role(user, user_data)
|
||||
|
|
@ -1461,7 +1464,7 @@ class OAuthManager:
|
|||
name=name,
|
||||
profile_image_url=picture_url,
|
||||
role=self.get_user_role(None, user_data),
|
||||
oauth_sub=provider_sub,
|
||||
oauth=oauth_data,
|
||||
)
|
||||
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ from open_webui.env import (
|
|||
OTEL_METRICS_OTLP_SPAN_EXPORTER,
|
||||
OTEL_METRICS_EXPORTER_OTLP_INSECURE,
|
||||
)
|
||||
from open_webui.socket.main import get_active_user_ids
|
||||
from open_webui.models.users import Users
|
||||
|
||||
_EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds
|
||||
|
|
@ -135,7 +134,7 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None:
|
|||
) -> Sequence[metrics.Observation]:
|
||||
return [
|
||||
metrics.Observation(
|
||||
value=len(get_active_user_ids()),
|
||||
value=Users.get_active_user_count(),
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -180,12 +180,17 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{#if _user?.oauth_sub}
|
||||
{#if _user?.oauth}
|
||||
<div class="flex flex-col w-full">
|
||||
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('OAuth ID')}</div>
|
||||
|
||||
<div class="flex-1 text-sm break-all mb-1">
|
||||
{_user.oauth_sub ?? ''}
|
||||
<div class="flex-1 text-sm break-all mb-1 flex flex-col space-y-1">
|
||||
{#each Object.keys(_user.oauth) as key}
|
||||
<div>
|
||||
<span class="text-gray-500">{key}</span>
|
||||
<span class="">{_user.oauth[key]?.sub}</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
</div>
|
||||
|
||||
<div class=" flex items-center gap-2">
|
||||
{#if user?.active}
|
||||
{#if user?.is_active}
|
||||
<div>
|
||||
<span class="relative flex size-2">
|
||||
<span
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@
|
|||
</DropdownMenu.Item>
|
||||
|
||||
{#if showActiveUsers && usage}
|
||||
{#if usage?.user_ids?.length > 0}
|
||||
{#if usage?.user_count}
|
||||
<hr class=" border-gray-50 dark:border-gray-800 my-1 p-0" />
|
||||
|
||||
<Tooltip
|
||||
|
|
@ -250,7 +250,7 @@
|
|||
{$i18n.t('Active Users')}:
|
||||
</span>
|
||||
<span class=" font-semibold">
|
||||
{usage?.user_ids?.length}
|
||||
{usage?.user_count}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@
|
|||
|
||||
let showRefresh = false;
|
||||
|
||||
let heartbeatInterval = null;
|
||||
|
||||
const BREAKPOINT = 768;
|
||||
|
||||
const setupSocket = async (enableWebsocket) => {
|
||||
|
|
@ -126,6 +128,14 @@
|
|||
}
|
||||
}
|
||||
|
||||
// Send heartbeat every 30 seconds
|
||||
heartbeatInterval = setInterval(() => {
|
||||
if (_socket.connected) {
|
||||
console.log('Sending heartbeat');
|
||||
_socket.emit('heartbeat', {});
|
||||
}
|
||||
}, 30000);
|
||||
|
||||
if (deploymentId !== null) {
|
||||
WEBUI_DEPLOYMENT_ID.set(deploymentId);
|
||||
}
|
||||
|
|
@ -154,6 +164,12 @@
|
|||
|
||||
_socket.on('disconnect', (reason, details) => {
|
||||
console.log(`Socket ${_socket.id} disconnected due to ${reason}`);
|
||||
|
||||
if (heartbeatInterval) {
|
||||
clearInterval(heartbeatInterval);
|
||||
heartbeatInterval = null;
|
||||
}
|
||||
|
||||
if (details) {
|
||||
console.log('Additional details:', details);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue