mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
enh: custom headers for external tool servers
This commit is contained in:
parent
0bf686396d
commit
da42850eff
5 changed files with 108 additions and 13 deletions
|
|
@ -144,6 +144,7 @@ class ToolServerConnection(BaseModel):
|
||||||
path: str
|
path: str
|
||||||
type: Optional[str] = "openapi" # openapi, mcp
|
type: Optional[str] = "openapi" # openapi, mcp
|
||||||
auth_type: Optional[str]
|
auth_type: Optional[str]
|
||||||
|
headers: Optional[dict]
|
||||||
key: Optional[str]
|
key: Optional[str]
|
||||||
config: Optional[dict]
|
config: Optional[dict]
|
||||||
|
|
||||||
|
|
@ -282,10 +283,14 @@ async def verify_tool_servers_config(
|
||||||
token = oauth_token.get("access_token", "")
|
token = oauth_token.get("access_token", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
if form_data.headers:
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
headers.update(form_data.headers)
|
||||||
|
|
||||||
await client.connect(form_data.url, headers=headers)
|
await client.connect(form_data.url, headers=headers)
|
||||||
specs = await client.list_tool_specs()
|
specs = await client.list_tool_specs()
|
||||||
return {
|
return {
|
||||||
|
|
@ -303,6 +308,7 @@ async def verify_tool_servers_config(
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
else: # openapi
|
else: # openapi
|
||||||
token = None
|
token = None
|
||||||
|
headers = None
|
||||||
if form_data.auth_type == "bearer":
|
if form_data.auth_type == "bearer":
|
||||||
token = form_data.key
|
token = form_data.key
|
||||||
elif form_data.auth_type == "session":
|
elif form_data.auth_type == "session":
|
||||||
|
|
@ -323,8 +329,16 @@ async def verify_tool_servers_config(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if token:
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
if form_data.headers:
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
headers.update(form_data.headers)
|
||||||
|
|
||||||
url = get_tool_server_url(form_data.url, form_data.path)
|
url = get_tool_server_url(form_data.url, form_data.path)
|
||||||
return await get_tool_server_data(token, url)
|
return await get_tool_server_data(url, headers=headers)
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -312,7 +312,11 @@ async def chat_completion_tools_handler(
|
||||||
for message in recent_messages
|
for message in recent_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = f"History:\n{chat_history}\nQuery: {user_message}" if chat_history else f"Query: {user_message}"
|
prompt = (
|
||||||
|
f"History:\n{chat_history}\nQuery: {user_message}"
|
||||||
|
if chat_history
|
||||||
|
else f"Query: {user_message}"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": task_model_id,
|
"model": task_model_id,
|
||||||
|
|
@ -1327,7 +1331,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
auth_type = mcp_server_connection.get("auth_type", "")
|
auth_type = mcp_server_connection.get("auth_type", "")
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if auth_type == "bearer":
|
if auth_type == "bearer":
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
|
|
@ -1363,6 +1366,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
oauth_token = None
|
oauth_token = None
|
||||||
|
|
||||||
|
connection_headers = mcp_server_connection.get("headers", None)
|
||||||
|
if connection_headers:
|
||||||
|
for key, value in connection_headers.items():
|
||||||
|
headers[key] = value
|
||||||
|
|
||||||
mcp_clients[server_id] = MCPClient()
|
mcp_clients[server_id] = MCPClient()
|
||||||
await mcp_clients[server_id].connect(
|
await mcp_clients[server_id].connect(
|
||||||
url=mcp_server_connection.get("url", ""),
|
url=mcp_server_connection.get("url", ""),
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,9 @@ async def get_tools(
|
||||||
auth_type = tool_server_connection.get("auth_type", "bearer")
|
auth_type = tool_server_connection.get("auth_type", "bearer")
|
||||||
|
|
||||||
cookies = {}
|
cookies = {}
|
||||||
headers = {}
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
if auth_type == "bearer":
|
if auth_type == "bearer":
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
|
|
@ -177,7 +179,10 @@ async def get_tools(
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
headers["Content-Type"] = "application/json"
|
connection_headers = tool_server_connection.get("headers", None)
|
||||||
|
if connection_headers:
|
||||||
|
for key, value in connection_headers.items():
|
||||||
|
headers[key] = value
|
||||||
|
|
||||||
def make_tool_function(
|
def make_tool_function(
|
||||||
function_name, tool_server_data, headers
|
function_name, tool_server_data, headers
|
||||||
|
|
@ -561,20 +566,21 @@ async def get_tool_servers(request: Request):
|
||||||
return tool_servers
|
return tool_servers
|
||||||
|
|
||||||
|
|
||||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
|
||||||
headers = {
|
_headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
if token:
|
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
if headers:
|
||||||
|
_headers.update(headers)
|
||||||
|
|
||||||
error = None
|
error = None
|
||||||
try:
|
try:
|
||||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
|
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_body = await response.json()
|
error_body = await response.json()
|
||||||
|
|
@ -644,7 +650,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
|
||||||
openapi_path = server.get("path", "openapi.json")
|
openapi_path = server.get("path", "openapi.json")
|
||||||
spec_url = get_tool_server_url(server_url, openapi_path)
|
spec_url = get_tool_server_url(server_url, openapi_path)
|
||||||
# Fetch from URL
|
# Fetch from URL
|
||||||
task = get_tool_server_data(token, spec_url)
|
task = get_tool_server_data(
|
||||||
|
spec_url,
|
||||||
|
{"Authorization": f"Bearer {token}"} if token else None,
|
||||||
|
)
|
||||||
elif spec_type == "json" and server.get("spec", ""):
|
elif spec_type == "json" and server.get("spec", ""):
|
||||||
# Use provided JSON spec
|
# Use provided JSON spec
|
||||||
spec_json = None
|
spec_json = None
|
||||||
|
|
|
||||||
|
|
@ -426,7 +426,7 @@
|
||||||
<div class="flex-1">
|
<div class="flex-1">
|
||||||
<Tooltip
|
<Tooltip
|
||||||
content={$i18n.t(
|
content={$i18n.t(
|
||||||
'Enter additional headers in JSON format (e.g. {{\'{{"X-Custom-Header": "value"}}\'}})'
|
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Textarea
|
<Textarea
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
import AccessControl from './workspace/common/AccessControl.svelte';
|
import AccessControl from './workspace/common/AccessControl.svelte';
|
||||||
import Spinner from '$lib/components/common/Spinner.svelte';
|
import Spinner from '$lib/components/common/Spinner.svelte';
|
||||||
import XMark from '$lib/components/icons/XMark.svelte';
|
import XMark from '$lib/components/icons/XMark.svelte';
|
||||||
|
import Textarea from './common/Textarea.svelte';
|
||||||
|
|
||||||
export let onSubmit: Function = () => {};
|
export let onSubmit: Function = () => {};
|
||||||
export let onDelete: Function = () => {};
|
export let onDelete: Function = () => {};
|
||||||
|
|
@ -44,6 +45,7 @@
|
||||||
|
|
||||||
let auth_type = 'bearer';
|
let auth_type = 'bearer';
|
||||||
let key = '';
|
let key = '';
|
||||||
|
let headers = '';
|
||||||
|
|
||||||
let accessControl = {};
|
let accessControl = {};
|
||||||
|
|
||||||
|
|
@ -110,6 +112,20 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (headers) {
|
||||||
|
try {
|
||||||
|
let _headers = JSON.parse(headers);
|
||||||
|
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
|
||||||
|
_headers = null;
|
||||||
|
throw new Error('Headers must be a valid JSON object');
|
||||||
|
}
|
||||||
|
headers = JSON.stringify(_headers, null, 2);
|
||||||
|
} catch (error) {
|
||||||
|
toast.error($i18n.t('Headers must be a valid JSON object'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (direct) {
|
if (direct) {
|
||||||
const res = await getToolServerData(
|
const res = await getToolServerData(
|
||||||
auth_type === 'bearer' ? key : localStorage.token,
|
auth_type === 'bearer' ? key : localStorage.token,
|
||||||
|
|
@ -128,6 +144,7 @@
|
||||||
path,
|
path,
|
||||||
type,
|
type,
|
||||||
auth_type,
|
auth_type,
|
||||||
|
headers: headers ? JSON.parse(headers) : undefined,
|
||||||
key,
|
key,
|
||||||
config: {
|
config: {
|
||||||
enable: enable,
|
enable: enable,
|
||||||
|
|
@ -177,6 +194,7 @@
|
||||||
if (data.path) path = data.path;
|
if (data.path) path = data.path;
|
||||||
|
|
||||||
if (data.auth_type) auth_type = data.auth_type;
|
if (data.auth_type) auth_type = data.auth_type;
|
||||||
|
if (data.headers) headers = JSON.stringify(data.headers, null, 2);
|
||||||
if (data.key) key = data.key;
|
if (data.key) key = data.key;
|
||||||
|
|
||||||
if (data.info) {
|
if (data.info) {
|
||||||
|
|
@ -210,6 +228,7 @@
|
||||||
path,
|
path,
|
||||||
|
|
||||||
auth_type,
|
auth_type,
|
||||||
|
headers: headers ? JSON.parse(headers) : undefined,
|
||||||
key,
|
key,
|
||||||
|
|
||||||
info: {
|
info: {
|
||||||
|
|
@ -256,6 +275,19 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (headers) {
|
||||||
|
try {
|
||||||
|
const _headers = JSON.parse(headers);
|
||||||
|
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
|
||||||
|
throw new Error('Headers must be a valid JSON object');
|
||||||
|
}
|
||||||
|
headers = JSON.stringify(_headers, null, 2);
|
||||||
|
} catch (error) {
|
||||||
|
toast.error($i18n.t('Headers must be a valid JSON object'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const connection = {
|
const connection = {
|
||||||
type,
|
type,
|
||||||
url,
|
url,
|
||||||
|
|
@ -265,9 +297,12 @@
|
||||||
path,
|
path,
|
||||||
|
|
||||||
auth_type,
|
auth_type,
|
||||||
|
headers: headers ? JSON.parse(headers) : undefined,
|
||||||
|
|
||||||
key,
|
key,
|
||||||
config: {
|
config: {
|
||||||
enable: enable,
|
enable: enable,
|
||||||
|
|
||||||
access_control: accessControl
|
access_control: accessControl
|
||||||
},
|
},
|
||||||
info: {
|
info: {
|
||||||
|
|
@ -313,6 +348,8 @@
|
||||||
path = connection?.path ?? 'openapi.json';
|
path = connection?.path ?? 'openapi.json';
|
||||||
|
|
||||||
auth_type = connection?.auth_type ?? 'bearer';
|
auth_type = connection?.auth_type ?? 'bearer';
|
||||||
|
headers = connection?.headers ? JSON.stringify(connection.headers, null, 2) : '';
|
||||||
|
|
||||||
key = connection?.key ?? '';
|
key = connection?.key ?? '';
|
||||||
|
|
||||||
id = connection.info?.id ?? '';
|
id = connection.info?.id ?? '';
|
||||||
|
|
@ -657,6 +694,33 @@
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if !direct}
|
{#if !direct}
|
||||||
|
<div class="flex gap-2 mt-2">
|
||||||
|
<div class="flex flex-col w-full">
|
||||||
|
<label
|
||||||
|
for="headers-input"
|
||||||
|
class={`mb-0.5 text-xs text-gray-500
|
||||||
|
${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : ''}`}
|
||||||
|
>{$i18n.t('Headers')}</label
|
||||||
|
>
|
||||||
|
|
||||||
|
<div class="flex-1">
|
||||||
|
<Tooltip
|
||||||
|
content={$i18n.t(
|
||||||
|
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<Textarea
|
||||||
|
className="w-full text-sm outline-hidden"
|
||||||
|
bind:value={headers}
|
||||||
|
placeholder={$i18n.t('Enter additional headers in JSON format')}
|
||||||
|
required={false}
|
||||||
|
minSize={30}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
|
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
|
||||||
|
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue