mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 20:35:19 +00:00
refac
This commit is contained in:
parent
f71834720e
commit
b5bb6ae177
5 changed files with 45 additions and 12 deletions
|
|
@ -219,6 +219,15 @@ async def generate_function_chat_completion(
|
||||||
__task__ = metadata.get("task", None)
|
__task__ = metadata.get("task", None)
|
||||||
__task_body__ = metadata.get("task_body", None)
|
__task_body__ = metadata.get("task_body", None)
|
||||||
|
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id,
|
||||||
|
request.cookies.get("oauth_session_id", None),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
extra_params = {
|
extra_params = {
|
||||||
"__event_emitter__": __event_emitter__,
|
"__event_emitter__": __event_emitter__,
|
||||||
"__event_call__": __event_call__,
|
"__event_call__": __event_call__,
|
||||||
|
|
@ -230,6 +239,7 @@ async def generate_function_chat_completion(
|
||||||
"__files__": files,
|
"__files__": files,
|
||||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||||
"__metadata__": metadata,
|
"__metadata__": metadata,
|
||||||
|
"__oauth_token__": oauth_token,
|
||||||
"__request__": request,
|
"__request__": request,
|
||||||
}
|
}
|
||||||
extra_params["__tools__"] = await get_tools(
|
extra_params["__tools__"] = await get_tools(
|
||||||
|
|
|
||||||
|
|
@ -1408,6 +1408,14 @@ async def chat_completion(
|
||||||
model_item = form_data.pop("model_item", {})
|
model_item = form_data.pop("model_item", {})
|
||||||
tasks = form_data.pop("background_tasks", None)
|
tasks = form_data.pop("background_tasks", None)
|
||||||
|
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id, request.cookies.get("oauth_session_id", None)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
try:
|
try:
|
||||||
if not model_item.get("direct", False):
|
if not model_item.get("direct", False):
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.auths import Auths
|
from open_webui.models.auths import Auths
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.users import (
|
from open_webui.models.users import (
|
||||||
|
|
@ -340,6 +342,18 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{user_id}/oauth/sessions", response_model=Optional[dict])
|
||||||
|
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||||
|
sessions = OAuthSessions.get_sessions_by_user_id(user_id)
|
||||||
|
if sessions and len(sessions) > 0:
|
||||||
|
return sessions
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetUserProfileImageById
|
# GetUserProfileImageById
|
||||||
############################
|
############################
|
||||||
|
|
|
||||||
|
|
@ -818,7 +818,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
oauth_token = None
|
oauth_token = None
|
||||||
try:
|
try:
|
||||||
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
||||||
user.id, request.cookies.get("oauth_session_id", None)
|
user.id,
|
||||||
|
request.cookies.get("oauth_session_id", None),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
@ -1493,11 +1494,21 @@ async def process_chat_response(
|
||||||
):
|
):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id,
|
||||||
|
request.cookies.get("oauth_session_id", None),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
extra_params = {
|
extra_params = {
|
||||||
"__event_emitter__": event_emitter,
|
"__event_emitter__": event_emitter,
|
||||||
"__event_call__": event_caller,
|
"__event_call__": event_caller,
|
||||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||||
"__metadata__": metadata,
|
"__metadata__": metadata,
|
||||||
|
"__oauth_token__": oauth_token,
|
||||||
"__request__": request,
|
"__request__": request,
|
||||||
"__model__": model,
|
"__model__": model,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -134,17 +134,7 @@ async def get_tools(
|
||||||
)
|
)
|
||||||
elif auth_type == "oauth":
|
elif auth_type == "oauth":
|
||||||
cookies = request.cookies
|
cookies = request.cookies
|
||||||
oauth_token = None
|
oauth_token = extra_params.get("__oauth_token__", None)
|
||||||
try:
|
|
||||||
oauth_token = (
|
|
||||||
request.app.state.oauth_manager.get_oauth_token(
|
|
||||||
user.id,
|
|
||||||
request.cookies.get("oauth_session_id", None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
|
||||||
|
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue