mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: oauth auth type in openai connection
This commit is contained in:
parent
474df5e534
commit
2b2d123531
3 changed files with 121 additions and 122 deletions
|
|
@ -119,6 +119,74 @@ def openai_reasoning_model_handler(payload):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def get_headers_and_cookies(
|
||||||
|
request: Request,
|
||||||
|
url,
|
||||||
|
key=None,
|
||||||
|
config=None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
user: UserModel = None,
|
||||||
|
):
|
||||||
|
cookies = {}
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"HTTP-Referer": "https://openwebui.com/",
|
||||||
|
"X-Title": "Open WebUI",
|
||||||
|
}
|
||||||
|
if "openrouter.ai" in url
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
**(
|
||||||
|
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
|
||||||
|
if metadata and metadata.get("chat_id")
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
token = None
|
||||||
|
auth_type = config.get("auth_type")
|
||||||
|
|
||||||
|
if auth_type == "bearer" or auth_type is None:
|
||||||
|
# Default to bearer if not specified
|
||||||
|
token = f"{key}"
|
||||||
|
elif auth_type == "none":
|
||||||
|
token = None
|
||||||
|
elif auth_type == "session":
|
||||||
|
cookies = request.cookies
|
||||||
|
token = request.state.token.credentials
|
||||||
|
elif auth_type == "oauth":
|
||||||
|
cookies = request.cookies
|
||||||
|
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id,
|
||||||
|
request.cookies.get("oauth_session_id", None),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
|
if oauth_token:
|
||||||
|
token = f"{oauth_token.get('access_token', '')}"
|
||||||
|
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
|
||||||
|
return headers, cookies
|
||||||
|
|
||||||
|
|
||||||
##########################################
|
##########################################
|
||||||
#
|
#
|
||||||
# API routes
|
# API routes
|
||||||
|
|
@ -210,34 +278,23 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
return FileResponse(file_path)
|
return FileResponse(file_path)
|
||||||
|
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||||
|
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||||
|
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||||
|
str(idx),
|
||||||
|
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||||
|
)
|
||||||
|
|
||||||
|
headers, cookies = get_headers_and_cookies(
|
||||||
|
request, url, key, api_config, user=user
|
||||||
|
)
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
url=f"{url}/audio/speech",
|
url=f"{url}/audio/speech",
|
||||||
data=body,
|
data=body,
|
||||||
headers={
|
headers=headers,
|
||||||
"Content-Type": "application/json",
|
cookies=cookies,
|
||||||
"Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
|
|
||||||
**(
|
|
||||||
{
|
|
||||||
"HTTP-Referer": "https://openwebui.com/",
|
|
||||||
"X-Title": "Open WebUI",
|
|
||||||
}
|
|
||||||
if "openrouter.ai" in url
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -401,7 +458,10 @@ async def get_filtered_models(models, user):
|
||||||
return filtered_models
|
return filtered_models
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=MODELS_CACHE_TTL, key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models")
|
@cached(
|
||||||
|
ttl=MODELS_CACHE_TTL,
|
||||||
|
key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models",
|
||||||
|
)
|
||||||
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||||
log.info("get_all_models()")
|
log.info("get_all_models()")
|
||||||
|
|
||||||
|
|
@ -489,19 +549,9 @@ async def get_models(
|
||||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||||
) as session:
|
) as session:
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers, cookies = get_headers_and_cookies(
|
||||||
"Content-Type": "application/json",
|
request, url, key, api_config, user=user
|
||||||
**(
|
)
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
if api_config.get("azure", False):
|
if api_config.get("azure", False):
|
||||||
models = {
|
models = {
|
||||||
|
|
@ -509,11 +559,10 @@ async def get_models(
|
||||||
"object": "list",
|
"object": "list",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
headers["Authorization"] = f"Bearer {key}"
|
|
||||||
|
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{url}/models",
|
f"{url}/models",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
) as r:
|
) as r:
|
||||||
if r.status != 200:
|
if r.status != 200:
|
||||||
|
|
@ -572,7 +621,9 @@ class ConnectionVerificationForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/verify")
|
@router.post("/verify")
|
||||||
async def verify_connection(
|
async def verify_connection(
|
||||||
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
request: Request,
|
||||||
|
form_data: ConnectionVerificationForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
url = form_data.url
|
url = form_data.url
|
||||||
key = form_data.key
|
key = form_data.key
|
||||||
|
|
@ -584,19 +635,9 @@ async def verify_connection(
|
||||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||||
) as session:
|
) as session:
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers, cookies = get_headers_and_cookies(
|
||||||
"Content-Type": "application/json",
|
request, url, key, api_config, user=user
|
||||||
**(
|
)
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
if api_config.get("azure", False):
|
if api_config.get("azure", False):
|
||||||
headers["api-key"] = key
|
headers["api-key"] = key
|
||||||
|
|
@ -605,6 +646,7 @@ async def verify_connection(
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url=f"{url}/openai/models?api-version={api_version}",
|
url=f"{url}/openai/models?api-version={api_version}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
) as r:
|
) as r:
|
||||||
try:
|
try:
|
||||||
|
|
@ -624,11 +666,10 @@ async def verify_connection(
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
else:
|
else:
|
||||||
headers["Authorization"] = f"Bearer {key}"
|
|
||||||
|
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{url}/models",
|
f"{url}/models",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
) as r:
|
) as r:
|
||||||
try:
|
try:
|
||||||
|
|
@ -836,32 +877,9 @@ async def generate_chat_completion(
|
||||||
convert_logit_bias_input_to_json(payload["logit_bias"])
|
convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||||
)
|
)
|
||||||
|
|
||||||
headers = {
|
headers, cookies = get_headers_and_cookies(
|
||||||
"Content-Type": "application/json",
|
request, url, key, api_config, metadata, user=user
|
||||||
**(
|
)
|
||||||
{
|
|
||||||
"HTTP-Referer": "https://openwebui.com/",
|
|
||||||
"X-Title": "Open WebUI",
|
|
||||||
}
|
|
||||||
if "openrouter.ai" in url
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
**(
|
|
||||||
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
|
|
||||||
if metadata and metadata.get("chat_id")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
if api_config.get("azure", False):
|
if api_config.get("azure", False):
|
||||||
api_version = api_config.get("api_version", "2023-03-15-preview")
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||||
|
|
@ -871,7 +889,6 @@ async def generate_chat_completion(
|
||||||
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
||||||
else:
|
else:
|
||||||
request_url = f"{url}/chat/completions"
|
request_url = f"{url}/chat/completions"
|
||||||
headers["Authorization"] = f"Bearer {key}"
|
|
||||||
|
|
||||||
payload = json.dumps(payload)
|
payload = json.dumps(payload)
|
||||||
|
|
||||||
|
|
@ -890,6 +907,7 @@ async def generate_chat_completion(
|
||||||
url=request_url,
|
url=request_url,
|
||||||
data=payload,
|
data=payload,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -951,31 +969,27 @@ async def embeddings(request: Request, form_data: dict, user):
|
||||||
models = request.app.state.OPENAI_MODELS
|
models = request.app.state.OPENAI_MODELS
|
||||||
if model_id in models:
|
if model_id in models:
|
||||||
idx = models[model_id]["urlIdx"]
|
idx = models[model_id]["urlIdx"]
|
||||||
|
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||||
|
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||||
|
str(idx),
|
||||||
|
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||||
|
)
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
session = None
|
session = None
|
||||||
streaming = False
|
streaming = False
|
||||||
|
|
||||||
|
headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user)
|
||||||
try:
|
try:
|
||||||
session = aiohttp.ClientSession(trust_env=True)
|
session = aiohttp.ClientSession(trust_env=True)
|
||||||
r = await session.request(
|
r = await session.request(
|
||||||
method="POST",
|
method="POST",
|
||||||
url=f"{url}/embeddings",
|
url=f"{url}/embeddings",
|
||||||
data=body,
|
data=body,
|
||||||
headers={
|
headers=headers,
|
||||||
"Authorization": f"Bearer {key}",
|
cookies=cookies,
|
||||||
"Content-Type": "application/json",
|
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||||
|
|
@ -1037,19 +1051,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||||
streaming = False
|
streaming = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers, cookies = get_headers_and_cookies(
|
||||||
"Content-Type": "application/json",
|
request, url, key, api_config, user=user
|
||||||
**(
|
)
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
if api_config.get("azure", False):
|
if api_config.get("azure", False):
|
||||||
api_version = api_config.get("api_version", "2023-03-15-preview")
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||||
|
|
@ -1062,7 +1066,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
request_url = f"{url}/{path}?api-version={api_version}"
|
request_url = f"{url}/{path}?api-version={api_version}"
|
||||||
else:
|
else:
|
||||||
headers["Authorization"] = f"Bearer {key}"
|
|
||||||
request_url = f"{url}/{path}"
|
request_url = f"{url}/{path}"
|
||||||
|
|
||||||
session = aiohttp.ClientSession(trust_env=True)
|
session = aiohttp.ClientSession(trust_env=True)
|
||||||
|
|
@ -1071,6 +1074,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||||
url=request_url,
|
url=request_url,
|
||||||
data=body,
|
data=body,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -138,12 +138,10 @@ async def get_tools(
|
||||||
elif auth_type == "oauth":
|
elif auth_type == "oauth":
|
||||||
cookies = request.cookies
|
cookies = request.cookies
|
||||||
oauth_token = extra_params.get("__oauth_token__", None)
|
oauth_token = extra_params.get("__oauth_token__", None)
|
||||||
|
if oauth_token:
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
)
|
)
|
||||||
elif auth_type == "request_headers":
|
|
||||||
cookies = request.cookies
|
|
||||||
headers.update(dict(request.headers))
|
|
||||||
|
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
|
|
@ -564,9 +562,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def get_tool_servers_data(
|
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
servers: List[Dict[str, Any]], session_token: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
# Prepare list of enabled servers along with their original index
|
# Prepare list of enabled servers along with their original index
|
||||||
server_entries = []
|
server_entries = []
|
||||||
for idx, server in enumerate(servers):
|
for idx, server in enumerate(servers):
|
||||||
|
|
@ -582,8 +578,9 @@ async def get_tool_servers_data(
|
||||||
|
|
||||||
if auth_type == "bearer":
|
if auth_type == "bearer":
|
||||||
token = server.get("key", "")
|
token = server.get("key", "")
|
||||||
elif auth_type == "session":
|
elif auth_type == "none":
|
||||||
token = session_token
|
# No authentication
|
||||||
|
pass
|
||||||
|
|
||||||
id = info.get("id")
|
id = info.get("id")
|
||||||
if not id:
|
if not id:
|
||||||
|
|
|
||||||
|
|
@ -443,8 +443,6 @@
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
<hr class=" border-gray-50 dark:border-gray-850 my-2.5 w-full" />
|
|
||||||
|
|
||||||
<div class="flex flex-col w-full">
|
<div class="flex flex-col w-full">
|
||||||
<div class="mb-1 flex justify-between">
|
<div class="mb-1 flex justify-between">
|
||||||
<div
|
<div
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue