From 2b2d123531eb3f42c0e940593832a64e2806240d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 8 Sep 2025 19:42:50 +0400 Subject: [PATCH] refac: oauth auth type in openai connection --- backend/open_webui/routers/openai.py | 222 ++++++++++--------- backend/open_webui/utils/tools.py | 19 +- src/lib/components/AddConnectionModal.svelte | 2 - 3 files changed, 121 insertions(+), 122 deletions(-) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index a94791bdf5..d7b5f75260 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -119,6 +119,74 @@ def openai_reasoning_model_handler(payload): return payload +def get_headers_and_cookies( + request: Request, + url, + key=None, + config=None, + metadata: Optional[dict] = None, + user: UserModel = None, +): + cookies = {} + headers = { + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + **( + {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")} + if metadata and metadata.get("chat_id") + else {} + ), + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + } + + token = None + auth_type = config.get("auth_type") + + if auth_type == "bearer" or auth_type is None: + # Default to bearer if not specified + token = f"{key}" + elif auth_type == "none": + token = None + elif auth_type == "session": + cookies = request.cookies + token = request.state.token.credentials + elif auth_type == "oauth": + cookies = request.cookies + + oauth_token = None + try: + oauth_token = request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + + if oauth_token: + token = f"{oauth_token.get('access_token', '')}" + + if token: + headers["Authorization"] = f"Bearer {token}" + + return headers, cookies + + ########################################## # # API routes @@ -210,34 +278,23 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + + headers, cookies = get_headers_and_cookies( + request, url, key, api_config, user=user + ) r = None try: r = requests.post( url=f"{url}/audio/speech", data=body, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", - **( - { - "HTTP-Referer": "https://openwebui.com/", - "X-Title": "Open WebUI", - } - if "openrouter.ai" in url - else {} - ), - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, + headers=headers, + cookies=cookies, stream=True, ) @@ -401,7 +458,10 @@ async def get_filtered_models(models, user): return filtered_models -@cached(ttl=MODELS_CACHE_TTL, key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models") +@cached( + ttl=MODELS_CACHE_TTL, + key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models", +) async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") @@ -489,19 +549,9 @@ async def get_models( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): models = { @@ -509,11 +559,10 @@ async def get_models( "object": "list", } else: - headers["Authorization"] = f"Bearer {key}" - async with session.get( f"{url}/models", headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: @@ -572,7 +621,9 @@ class ConnectionVerificationForm(BaseModel): @router.post("/verify") async def verify_connection( - form_data: ConnectionVerificationForm, user=Depends(get_admin_user) + request: Request, + form_data: ConnectionVerificationForm, + user=Depends(get_admin_user), ): url = form_data.url key = form_data.key @@ -584,19 +635,9 @@ async def verify_connection( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): headers["api-key"] = key @@ -605,6 +646,7 @@ async def verify_connection( async with session.get( url=f"{url}/openai/models?api-version={api_version}", headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: try: @@ -624,11 +666,10 @@ async def verify_connection( return response_data else: - headers["Authorization"] = f"Bearer {key}" - async with session.get( f"{url}/models", headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: try: @@ -836,32 +877,9 @@ async def generate_chat_completion( convert_logit_bias_input_to_json(payload["logit_bias"]) ) - headers = { - "Content-Type": "application/json", - **( - { - "HTTP-Referer": "https://openwebui.com/", - "X-Title": "Open WebUI", - } - if "openrouter.ai" in url - else {} - ), - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - **( - {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")} - if metadata and metadata.get("chat_id") - else {} - ), - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = get_headers_and_cookies( + request, url, key, api_config, metadata, user=user + ) if api_config.get("azure", False): api_version = api_config.get("api_version", "2023-03-15-preview") @@ -871,7 +889,6 @@ async def generate_chat_completion( request_url = f"{request_url}/chat/completions?api-version={api_version}" else: request_url = f"{url}/chat/completions" - headers["Authorization"] = f"Bearer {key}" payload = json.dumps(payload) @@ -890,6 +907,7 @@ async def generate_chat_completion( url=request_url, data=payload, headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) @@ -951,31 +969,27 @@ async def embeddings(request: Request, form_data: dict, user): models = request.app.state.OPENAI_MODELS if model_id in models: idx = models[model_id]["urlIdx"] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + r = None session = None streaming = False + + headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user) try: session = aiohttp.ClientSession(trust_env=True) r = await session.request( method="POST", url=f"{url}/embeddings", data=body, - headers={ - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS and user - else {} - ), - }, + headers=headers, + cookies=cookies, ) if "text/event-stream" in r.headers.get("Content-Type", ""): @@ -1037,19 +1051,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): streaming = False try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): api_version = api_config.get("api_version", "2023-03-15-preview") @@ -1062,7 +1066,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): request_url = f"{url}/{path}?api-version={api_version}" else: - headers["Authorization"] = f"Bearer {key}" request_url = f"{url}/{path}" session = aiohttp.ClientSession(trust_env=True) @@ -1071,6 +1074,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): url=request_url, data=body, headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 74974e1461..fd441df9e1 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -138,12 +138,10 @@ async def get_tools( elif auth_type == "oauth": cookies = request.cookies oauth_token = extra_params.get("__oauth_token__", None) - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) - elif auth_type == "request_headers": - cookies = request.cookies - headers.update(dict(request.headers)) + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) headers["Content-Type"] = "application/json" @@ -564,9 +562,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: return data -async def get_tool_servers_data( - servers: List[Dict[str, Any]], session_token: Optional[str] = None -) -> List[Dict[str, Any]]: +async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Prepare list of enabled servers along with their original index server_entries = [] for idx, server in enumerate(servers): @@ -582,8 +578,9 @@ async def get_tool_servers_data( if auth_type == "bearer": token = server.get("key", "") - elif auth_type == "session": - token = session_token + elif auth_type == "none": + # No authentication + pass id = info.get("id") if not id: diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index ecf03ee2ce..900dc74d2b 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -443,8 +443,6 @@ {/if} -
-