refac: oauth auth type in openai connection

This commit is contained in:
Timothy Jaeryang Baek 2025-09-08 19:42:50 +04:00
parent 474df5e534
commit 2b2d123531
3 changed files with 121 additions and 122 deletions

View file

@ -119,6 +119,74 @@ def openai_reasoning_model_handler(payload):
return payload return payload
def get_headers_and_cookies(
request: Request,
url,
key=None,
config=None,
metadata: Optional[dict] = None,
user: UserModel = None,
):
cookies = {}
headers = {
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
token = None
auth_type = config.get("auth_type")
if auth_type == "bearer" or auth_type is None:
# Default to bearer if not specified
token = f"{key}"
elif auth_type == "none":
token = None
elif auth_type == "session":
cookies = request.cookies
token = request.state.token.credentials
elif auth_type == "oauth":
cookies = request.cookies
oauth_token = None
try:
oauth_token = request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
if oauth_token:
token = f"{oauth_token.get('access_token', '')}"
if token:
headers["Authorization"] = f"Bearer {token}"
return headers, cookies
########################################## ##########################################
# #
# API routes # API routes
@ -210,34 +278,23 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
url = request.app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
headers, cookies = get_headers_and_cookies(
request, url, key, api_config, user=user
)
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{url}/audio/speech", url=f"{url}/audio/speech",
data=body, data=body,
headers={ headers=headers,
"Content-Type": "application/json", cookies=cookies,
"Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
stream=True, stream=True,
) )
@ -401,7 +458,10 @@ async def get_filtered_models(models, user):
return filtered_models return filtered_models
@cached(ttl=MODELS_CACHE_TTL, key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models") @cached(
ttl=MODELS_CACHE_TTL,
key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models",
)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()") log.info("get_all_models()")
@ -489,19 +549,9 @@ async def get_models(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers = { headers, cookies = get_headers_and_cookies(
"Content-Type": "application/json", request, url, key, api_config, user=user
**( )
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False): if api_config.get("azure", False):
models = { models = {
@ -509,11 +559,10 @@ async def get_models(
"object": "list", "object": "list",
} }
else: else:
headers["Authorization"] = f"Bearer {key}"
async with session.get( async with session.get(
f"{url}/models", f"{url}/models",
headers=headers, headers=headers,
cookies=cookies,
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
if r.status != 200: if r.status != 200:
@ -572,7 +621,9 @@ class ConnectionVerificationForm(BaseModel):
@router.post("/verify") @router.post("/verify")
async def verify_connection( async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user) request: Request,
form_data: ConnectionVerificationForm,
user=Depends(get_admin_user),
): ):
url = form_data.url url = form_data.url
key = form_data.key key = form_data.key
@ -584,19 +635,9 @@ async def verify_connection(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers = { headers, cookies = get_headers_and_cookies(
"Content-Type": "application/json", request, url, key, api_config, user=user
**( )
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False): if api_config.get("azure", False):
headers["api-key"] = key headers["api-key"] = key
@ -605,6 +646,7 @@ async def verify_connection(
async with session.get( async with session.get(
url=f"{url}/openai/models?api-version={api_version}", url=f"{url}/openai/models?api-version={api_version}",
headers=headers, headers=headers,
cookies=cookies,
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
try: try:
@ -624,11 +666,10 @@ async def verify_connection(
return response_data return response_data
else: else:
headers["Authorization"] = f"Bearer {key}"
async with session.get( async with session.get(
f"{url}/models", f"{url}/models",
headers=headers, headers=headers,
cookies=cookies,
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
try: try:
@ -836,32 +877,9 @@ async def generate_chat_completion(
convert_logit_bias_input_to_json(payload["logit_bias"]) convert_logit_bias_input_to_json(payload["logit_bias"])
) )
headers = { headers, cookies = get_headers_and_cookies(
"Content-Type": "application/json", request, url, key, api_config, metadata, user=user
**( )
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False): if api_config.get("azure", False):
api_version = api_config.get("api_version", "2023-03-15-preview") api_version = api_config.get("api_version", "2023-03-15-preview")
@ -871,7 +889,6 @@ async def generate_chat_completion(
request_url = f"{request_url}/chat/completions?api-version={api_version}" request_url = f"{request_url}/chat/completions?api-version={api_version}"
else: else:
request_url = f"{url}/chat/completions" request_url = f"{url}/chat/completions"
headers["Authorization"] = f"Bearer {key}"
payload = json.dumps(payload) payload = json.dumps(payload)
@ -890,6 +907,7 @@ async def generate_chat_completion(
url=request_url, url=request_url,
data=payload, data=payload,
headers=headers, headers=headers,
cookies=cookies,
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
@ -951,31 +969,27 @@ async def embeddings(request: Request, form_data: dict, user):
models = request.app.state.OPENAI_MODELS models = request.app.state.OPENAI_MODELS
if model_id in models: if model_id in models:
idx = models[model_id]["urlIdx"] idx = models[model_id]["urlIdx"]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
r = None r = None
session = None session = None
streaming = False streaming = False
headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user)
try: try:
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(trust_env=True)
r = await session.request( r = await session.request(
method="POST", method="POST",
url=f"{url}/embeddings", url=f"{url}/embeddings",
data=body, data=body,
headers={ headers=headers,
"Authorization": f"Bearer {key}", cookies=cookies,
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) )
if "text/event-stream" in r.headers.get("Content-Type", ""): if "text/event-stream" in r.headers.get("Content-Type", ""):
@ -1037,19 +1051,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
streaming = False streaming = False
try: try:
headers = { headers, cookies = get_headers_and_cookies(
"Content-Type": "application/json", request, url, key, api_config, user=user
**( )
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False): if api_config.get("azure", False):
api_version = api_config.get("api_version", "2023-03-15-preview") api_version = api_config.get("api_version", "2023-03-15-preview")
@ -1062,7 +1066,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
request_url = f"{url}/{path}?api-version={api_version}" request_url = f"{url}/{path}?api-version={api_version}"
else: else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}" request_url = f"{url}/{path}"
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(trust_env=True)
@ -1071,6 +1074,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
url=request_url, url=request_url,
data=body, data=body,
headers=headers, headers=headers,
cookies=cookies,
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )

View file

@ -138,12 +138,10 @@ async def get_tools(
elif auth_type == "oauth": elif auth_type == "oauth":
cookies = request.cookies cookies = request.cookies
oauth_token = extra_params.get("__oauth_token__", None) oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = ( headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}" f"Bearer {oauth_token.get('access_token', '')}"
) )
elif auth_type == "request_headers":
cookies = request.cookies
headers.update(dict(request.headers))
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
@ -564,9 +562,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
return data return data
async def get_tool_servers_data( async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
servers: List[Dict[str, Any]], session_token: Optional[str] = None
) -> List[Dict[str, Any]]:
# Prepare list of enabled servers along with their original index # Prepare list of enabled servers along with their original index
server_entries = [] server_entries = []
for idx, server in enumerate(servers): for idx, server in enumerate(servers):
@ -582,8 +578,9 @@ async def get_tool_servers_data(
if auth_type == "bearer": if auth_type == "bearer":
token = server.get("key", "") token = server.get("key", "")
elif auth_type == "session": elif auth_type == "none":
token = session_token # No authentication
pass
id = info.get("id") id = info.get("id")
if not id: if not id:

View file

@ -443,8 +443,6 @@
</div> </div>
{/if} {/if}
<hr class=" border-gray-50 dark:border-gray-850 my-2.5 w-full" />
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class="mb-1 flex justify-between"> <div class="mb-1 flex justify-between">
<div <div