mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-14 05:15:18 +00:00
Add x-Open-Webui headers for ollama + more for openai
This commit is contained in:
parent
e4d7d41df6
commit
6d62e71c34
5 changed files with 173 additions and 33 deletions
|
|
@ -858,7 +858,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
return filtered_models
|
return filtered_models
|
||||||
|
|
||||||
models = await get_all_models(request)
|
models = await get_all_models(request, user=user)
|
||||||
|
|
||||||
# Filter out filter pipelines
|
# Filter out filter pipelines
|
||||||
models = [
|
models = [
|
||||||
|
|
@ -898,7 +898,7 @@ async def chat_completion(
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
if not request.app.state.MODELS:
|
if not request.app.state.MODELS:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
|
|
||||||
model_item = form_data.pop("model_item", {})
|
model_item = form_data.pop("model_item", {})
|
||||||
tasks = form_data.pop("background_tasks", None)
|
tasks = form_data.pop("background_tasks", None)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,11 @@ from urllib.parse import urlparse
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiocache import cached
|
from aiocache import cached
|
||||||
import requests
|
import requests
|
||||||
|
from open_webui.models.users import UserModel
|
||||||
|
|
||||||
|
from open_webui.env import (
|
||||||
|
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||||
|
)
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
Depends,
|
Depends,
|
||||||
|
|
@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||||
##########################################
|
##########################################
|
||||||
|
|
||||||
|
|
||||||
async def send_get_request(url, key=None):
|
async def send_get_request(url, key=None, user: UserModel = None):
|
||||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
) as response:
|
) as response:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -96,6 +115,7 @@ async def send_post_request(
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
content_type: Optional[str] = None,
|
content_type: Optional[str] = None,
|
||||||
|
user: UserModel = None
|
||||||
):
|
):
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
|
|
@ -110,6 +130,16 @@ async def send_post_request(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
@ -191,7 +221,19 @@ async def verify_connection(
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{url}/api/version",
|
f"{url}/api/version",
|
||||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
headers={
|
||||||
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
) as r:
|
) as r:
|
||||||
if r.status != 200:
|
if r.status != 200:
|
||||||
detail = f"HTTP Error: {r.status}"
|
detail = f"HTTP Error: {r.status}"
|
||||||
|
|
@ -254,7 +296,7 @@ async def update_config(
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=3)
|
@cached(ttl=3)
|
||||||
async def get_all_models(request: Request):
|
async def get_all_models(request: Request, user: UserModel=None):
|
||||||
log.info("get_all_models()")
|
log.info("get_all_models()")
|
||||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||||
request_tasks = []
|
request_tasks = []
|
||||||
|
|
@ -262,7 +304,7 @@ async def get_all_models(request: Request):
|
||||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||||
):
|
):
|
||||||
request_tasks.append(send_get_request(f"{url}/api/tags"))
|
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
|
||||||
else:
|
else:
|
||||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||||
str(idx),
|
str(idx),
|
||||||
|
|
@ -275,7 +317,7 @@ async def get_all_models(request: Request):
|
||||||
key = api_config.get("key", None)
|
key = api_config.get("key", None)
|
||||||
|
|
||||||
if enable:
|
if enable:
|
||||||
request_tasks.append(send_get_request(f"{url}/api/tags", key))
|
request_tasks.append(send_get_request(f"{url}/api/tags", key, user=user))
|
||||||
else:
|
else:
|
||||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||||
|
|
||||||
|
|
@ -360,7 +402,7 @@ async def get_ollama_tags(
|
||||||
models = []
|
models = []
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models(request)
|
models = await get_all_models(request, user=user)
|
||||||
else:
|
else:
|
||||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||||
|
|
@ -370,7 +412,19 @@ async def get_ollama_tags(
|
||||||
r = requests.request(
|
r = requests.request(
|
||||||
method="GET",
|
method="GET",
|
||||||
url=f"{url}/api/tags",
|
url=f"{url}/api/tags",
|
||||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
headers={
|
||||||
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
|
|
@ -477,6 +531,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
|
||||||
url, {}
|
url, {}
|
||||||
), # Legacy support
|
), # Legacy support
|
||||||
).get("key", None),
|
).get("key", None),
|
||||||
|
user=user
|
||||||
)
|
)
|
||||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||||
]
|
]
|
||||||
|
|
@ -509,6 +564,7 @@ async def pull_model(
|
||||||
url=f"{url}/api/pull",
|
url=f"{url}/api/pull",
|
||||||
payload=json.dumps(payload),
|
payload=json.dumps(payload),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -527,7 +583,7 @@ async def push_model(
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.name in models:
|
if form_data.name in models:
|
||||||
|
|
@ -545,6 +601,7 @@ async def push_model(
|
||||||
url=f"{url}/api/push",
|
url=f"{url}/api/push",
|
||||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -571,6 +628,7 @@ async def create_model(
|
||||||
url=f"{url}/api/create",
|
url=f"{url}/api/create",
|
||||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -588,7 +646,7 @@ async def copy_model(
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.source in models:
|
if form_data.source in models:
|
||||||
|
|
@ -609,6 +667,16 @@ async def copy_model(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
)
|
)
|
||||||
|
|
@ -643,7 +711,7 @@ async def delete_model(
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.name in models:
|
if form_data.name in models:
|
||||||
|
|
@ -665,6 +733,16 @@ async def delete_model(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
@ -693,7 +771,7 @@ async def delete_model(
|
||||||
async def show_model_info(
|
async def show_model_info(
|
||||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.name not in models:
|
if form_data.name not in models:
|
||||||
|
|
@ -714,6 +792,16 @@ async def show_model_info(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
)
|
)
|
||||||
|
|
@ -757,7 +845,7 @@ async def embed(
|
||||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
|
|
@ -783,6 +871,16 @@ async def embed(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
)
|
)
|
||||||
|
|
@ -826,7 +924,7 @@ async def embeddings(
|
||||||
log.info(f"generate_ollama_embeddings {form_data}")
|
log.info(f"generate_ollama_embeddings {form_data}")
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
|
|
@ -852,6 +950,16 @@ async def embeddings(
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
)
|
)
|
||||||
|
|
@ -901,7 +1009,7 @@ async def generate_completion(
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
|
|
@ -931,6 +1039,7 @@ async def generate_completion(
|
||||||
url=f"{url}/api/generate",
|
url=f"{url}/api/generate",
|
||||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1047,6 +1156,7 @@ async def generate_chat_completion(
|
||||||
stream=form_data.stream,
|
stream=form_data.stream,
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
content_type="application/x-ndjson",
|
content_type="application/x-ndjson",
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1149,6 +1259,7 @@ async def generate_openai_completion(
|
||||||
payload=json.dumps(payload),
|
payload=json.dumps(payload),
|
||||||
stream=payload.get("stream", False),
|
stream=payload.get("stream", False),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1227,6 +1338,7 @@ async def generate_openai_chat_completion(
|
||||||
payload=json.dumps(payload),
|
payload=json.dumps(payload),
|
||||||
stream=payload.get("stream", False),
|
stream=payload.get("stream", False),
|
||||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1240,7 +1352,7 @@ async def get_openai_models(
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
model_list = await get_all_models(request)
|
model_list = await get_all_models(request, user=user)
|
||||||
models = [
|
models = [
|
||||||
{
|
{
|
||||||
"id": model["model"],
|
"id": model["model"],
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from open_webui.env import (
|
||||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||||
BYPASS_MODEL_ACCESS_CONTROL,
|
BYPASS_MODEL_ACCESS_CONTROL,
|
||||||
)
|
)
|
||||||
|
from open_webui.models.users import UserModel
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||||
|
|
@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||||
##########################################
|
##########################################
|
||||||
|
|
||||||
|
|
||||||
async def send_get_request(url, key=None):
|
async def send_get_request(url, key=None, user: UserModel=None):
|
||||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
url,
|
||||||
|
headers={
|
||||||
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
}
|
||||||
) as response:
|
) as response:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -247,7 +261,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||||
|
|
||||||
|
|
||||||
async def get_all_models_responses(request: Request) -> list:
|
async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -271,7 +285,9 @@ async def get_all_models_responses(request: Request) -> list:
|
||||||
):
|
):
|
||||||
request_tasks.append(
|
request_tasks.append(
|
||||||
send_get_request(
|
send_get_request(
|
||||||
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
|
f"{url}/models",
|
||||||
|
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -291,6 +307,7 @@ async def get_all_models_responses(request: Request) -> list:
|
||||||
send_get_request(
|
send_get_request(
|
||||||
f"{url}/models",
|
f"{url}/models",
|
||||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -352,13 +369,13 @@ async def get_filtered_models(models, user):
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=3)
|
@cached(ttl=3)
|
||||||
async def get_all_models(request: Request) -> dict[str, list]:
|
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||||
log.info("get_all_models()")
|
log.info("get_all_models()")
|
||||||
|
|
||||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||||
return {"data": []}
|
return {"data": []}
|
||||||
|
|
||||||
responses = await get_all_models_responses(request)
|
responses = await get_all_models_responses(request, user=user)
|
||||||
|
|
||||||
def extract_data(response):
|
def extract_data(response):
|
||||||
if response and "data" in response:
|
if response and "data" in response:
|
||||||
|
|
@ -418,7 +435,7 @@ async def get_models(
|
||||||
}
|
}
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models(request)
|
models = await get_all_models(request, user=user)
|
||||||
else:
|
else:
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||||
|
|
@ -515,6 +532,16 @@ async def verify_connection(
|
||||||
headers={
|
headers={
|
||||||
"Authorization": f"Bearer {key}",
|
"Authorization": f"Bearer {key}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
) as r:
|
) as r:
|
||||||
if r.status != 200:
|
if r.status != 200:
|
||||||
|
|
@ -587,7 +614,7 @@ async def generate_chat_completion(
|
||||||
detail="Model not found",
|
detail="Model not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||||
if model:
|
if model:
|
||||||
idx = model["urlIdx"]
|
idx = model["urlIdx"]
|
||||||
|
|
|
||||||
|
|
@ -285,7 +285,7 @@ chat_completion = generate_chat_completion
|
||||||
|
|
||||||
async def chat_completed(request: Request, form_data: dict, user: Any):
|
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||||
if not request.app.state.MODELS:
|
if not request.app.state.MODELS:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
|
|
||||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||||
models = {
|
models = {
|
||||||
|
|
@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||||
raise Exception(f"Action not found: {action_id}")
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
if not request.app.state.MODELS:
|
if not request.app.state.MODELS:
|
||||||
await get_all_models(request)
|
await get_all_models(request, user=user)
|
||||||
|
|
||||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||||
models = {
|
models = {
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from open_webui.config import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
from open_webui.models.users import UserModel
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
|
|
@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
async def get_all_base_models(request: Request):
|
async def get_all_base_models(request: Request, user: UserModel=None):
|
||||||
function_models = []
|
function_models = []
|
||||||
openai_models = []
|
openai_models = []
|
||||||
ollama_models = []
|
ollama_models = []
|
||||||
|
|
||||||
if request.app.state.config.ENABLE_OPENAI_API:
|
if request.app.state.config.ENABLE_OPENAI_API:
|
||||||
openai_models = await openai.get_all_models(request)
|
openai_models = await openai.get_all_models(request, user=user)
|
||||||
openai_models = openai_models["data"]
|
openai_models = openai_models["data"]
|
||||||
|
|
||||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||||
ollama_models = await ollama.get_all_models(request)
|
ollama_models = await ollama.get_all_models(request, user=user)
|
||||||
ollama_models = [
|
ollama_models = [
|
||||||
{
|
{
|
||||||
"id": model["model"],
|
"id": model["model"],
|
||||||
|
|
@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
async def get_all_models(request):
|
async def get_all_models(request, user: UserModel=None):
|
||||||
models = await get_all_base_models(request)
|
models = await get_all_base_models(request, user=user)
|
||||||
|
|
||||||
# If there are no models, return an empty list
|
# If there are no models, return an empty list
|
||||||
if len(models) == 0:
|
if len(models) == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue