fix: oauth token

This commit is contained in:
Timothy Jaeryang Baek 2025-09-19 00:10:48 -05:00
parent a89ffccd7e
commit e4c4ba0979
4 changed files with 15 additions and 13 deletions

View file

@ -239,7 +239,7 @@ async def generate_function_chat_completion(
oauth_token = None oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
oauth_token = request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get("oauth_session_id", None),
) )

View file

@ -121,7 +121,7 @@ def openai_reasoning_model_handler(payload):
return payload return payload
def get_headers_and_cookies( async def get_headers_and_cookies(
request: Request, request: Request,
url, url,
key=None, key=None,
@ -174,7 +174,7 @@ def get_headers_and_cookies(
oauth_token = None oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
oauth_token = request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get("oauth_session_id", None),
) )
@ -305,7 +305,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
) )
headers, cookies = get_headers_and_cookies( headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user request, url, key, api_config, user=user
) )
@ -570,7 +570,7 @@ async def get_models(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers, cookies = get_headers_and_cookies( headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user request, url, key, api_config, user=user
) )
@ -656,7 +656,7 @@ async def verify_connection(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers, cookies = get_headers_and_cookies( headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user request, url, key, api_config, user=user
) )
@ -901,7 +901,7 @@ async def generate_chat_completion(
convert_logit_bias_input_to_json(payload["logit_bias"]) convert_logit_bias_input_to_json(payload["logit_bias"])
) )
headers, cookies = get_headers_and_cookies( headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, metadata, user=user request, url, key, api_config, metadata, user=user
) )
@ -1010,7 +1010,9 @@ async def embeddings(request: Request, form_data: dict, user):
session = None session = None
streaming = False streaming = False
headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user) headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user
)
try: try:
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(trust_env=True)
r = await session.request( r = await session.request(
@ -1080,7 +1082,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
streaming = False streaming = False
try: try:
headers, cookies = get_headers_and_cookies( headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user request, url, key, api_config, user=user
) )

View file

@ -818,7 +818,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
oauth_token = None oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
oauth_token = request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get("oauth_session_id", None),
) )
@ -1498,7 +1498,7 @@ async def process_chat_response(
oauth_token = None oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
oauth_token = request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get("oauth_session_id", None),
) )

View file

@ -157,7 +157,7 @@ class OAuthManager:
) )
return None return None
def get_oauth_token( async def get_oauth_token(
self, user_id: str, session_id: str, force_refresh: bool = False self, user_id: str, session_id: str, force_refresh: bool = False
): ):
""" """
@ -186,7 +186,7 @@ class OAuthManager:
log.debug( log.debug(
f"Token refresh needed for user {user_id}, provider {session.provider}" f"Token refresh needed for user {user_id}, provider {session.provider}"
) )
refreshed_token = self._refresh_token(session) refreshed_token = await self._refresh_token(session)
if refreshed_token: if refreshed_token:
return refreshed_token return refreshed_token
else: else: