mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 13:55:19 +00:00
refac: oauth pass client auth params
This commit is contained in:
parent
6b638db114
commit
6d9a562edd
1 changed files with 18 additions and 6 deletions
|
|
@ -333,7 +333,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
||||||
|
|
||||||
# The mcp package requires optional unset values to be None. If an empty string is passed, it gets validated and fails.
|
# The mcp package requires optional unset values to be None. If an empty string is passed, it gets validated and fails.
|
||||||
# This replaces all empty strings with None.
|
# This replaces all empty strings with None.
|
||||||
registration_response_json = {k: (None if v == "" else v) for k, v in registration_response_json.items()}
|
registration_response_json = {
|
||||||
|
k: (None if v == "" else v)
|
||||||
|
for k, v in registration_response_json.items()
|
||||||
|
}
|
||||||
oauth_client_info = OAuthClientInformationFull.model_validate(
|
oauth_client_info = OAuthClientInformationFull.model_validate(
|
||||||
{
|
{
|
||||||
**registration_response_json,
|
**registration_response_json,
|
||||||
|
|
@ -694,16 +697,17 @@ class OAuthClientManager:
|
||||||
error_message = None
|
error_message = None
|
||||||
try:
|
try:
|
||||||
client_info = self.get_client_info(client_id)
|
client_info = self.get_client_info(client_id)
|
||||||
token_params = {}
|
|
||||||
|
auth_params = {}
|
||||||
if (
|
if (
|
||||||
client_info
|
client_info
|
||||||
and hasattr(client_info, "client_id")
|
and hasattr(client_info, "client_id")
|
||||||
and hasattr(client_info, "client_secret")
|
and hasattr(client_info, "client_secret")
|
||||||
):
|
):
|
||||||
token_params["client_id"] = client_info.client_id
|
auth_params["client_id"] = client_info.client_id
|
||||||
token_params["client_secret"] = client_info.client_secret
|
auth_params["client_secret"] = client_info.client_secret
|
||||||
|
|
||||||
token = await client.authorize_access_token(request, **token_params)
|
token = await client.authorize_access_token(request, **auth_params)
|
||||||
if token:
|
if token:
|
||||||
try:
|
try:
|
||||||
# Add timestamp for tracking
|
# Add timestamp for tracking
|
||||||
|
|
@ -1228,8 +1232,16 @@ class OAuthManager:
|
||||||
error_message = None
|
error_message = None
|
||||||
try:
|
try:
|
||||||
client = self.get_client(provider)
|
client = self.get_client(provider)
|
||||||
|
|
||||||
|
auth_params = {}
|
||||||
|
if client:
|
||||||
|
if hasattr(client, "client_id"):
|
||||||
|
auth_params["client_id"] = client.client_id
|
||||||
|
if hasattr(client, "client_secret"):
|
||||||
|
auth_params["client_secret"] = client.client_secret
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = await client.authorize_access_token(request)
|
token = await client.authorize_access_token(request, **auth_params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
detailed_error = _build_oauth_callback_error_message(e)
|
detailed_error = _build_oauth_callback_error_message(e)
|
||||||
log.warning(
|
log.warning(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue