mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
refac
This commit is contained in:
parent
143d3fbce2
commit
9f42b9369f
1 changed files with 16 additions and 7 deletions
|
|
@ -24,8 +24,10 @@ from sqlalchemy import (
|
||||||
Date,
|
Date,
|
||||||
exists,
|
exists,
|
||||||
select,
|
select,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
from sqlalchemy import or_, case
|
from sqlalchemy import or_, case
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
@ -296,14 +298,21 @@ class UsersTable:
|
||||||
|
|
||||||
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db: # type: Session
|
||||||
user = (
|
dialect_name = db.bind.dialect.name
|
||||||
db.query(User)
|
|
||||||
.filter(User.oauth.contains({provider: {"sub": sub}}))
|
query = db.query(User)
|
||||||
.first()
|
if dialect_name == "sqlite":
|
||||||
|
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
|
||||||
|
elif dialect_name == "postgresql":
|
||||||
|
query = query.filter(
|
||||||
|
User.oauth[provider].cast(JSONB)["sub"].astext == sub
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user = query.first()
|
||||||
return UserModel.model_validate(user) if user else None
|
return UserModel.model_validate(user) if user else None
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
# You may want to log the exception here
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_users(
|
def get_users(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue