mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
enh: custom headers for external tool servers
This commit is contained in:
parent
0bf686396d
commit
da42850eff
5 changed files with 108 additions and 13 deletions
|
|
@ -144,6 +144,7 @@ class ToolServerConnection(BaseModel):
|
|||
path: str
|
||||
type: Optional[str] = "openapi" # openapi, mcp
|
||||
auth_type: Optional[str]
|
||||
headers: Optional[dict]
|
||||
key: Optional[str]
|
||||
config: Optional[dict]
|
||||
|
||||
|
|
@ -282,10 +283,14 @@ async def verify_tool_servers_config(
|
|||
token = oauth_token.get("access_token", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if token:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
if form_data.headers:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(form_data.headers)
|
||||
|
||||
await client.connect(form_data.url, headers=headers)
|
||||
specs = await client.list_tool_specs()
|
||||
return {
|
||||
|
|
@ -303,6 +308,7 @@ async def verify_tool_servers_config(
|
|||
await client.disconnect()
|
||||
else: # openapi
|
||||
token = None
|
||||
headers = None
|
||||
if form_data.auth_type == "bearer":
|
||||
token = form_data.key
|
||||
elif form_data.auth_type == "session":
|
||||
|
|
@ -323,8 +329,16 @@ async def verify_tool_servers_config(
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
if token:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
if form_data.headers:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(form_data.headers)
|
||||
|
||||
url = get_tool_server_url(form_data.url, form_data.path)
|
||||
return await get_tool_server_data(token, url)
|
||||
return await get_tool_server_data(url, headers=headers)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -312,7 +312,11 @@ async def chat_completion_tools_handler(
|
|||
for message in recent_messages
|
||||
)
|
||||
|
||||
prompt = f"History:\n{chat_history}\nQuery: {user_message}" if chat_history else f"Query: {user_message}"
|
||||
prompt = (
|
||||
f"History:\n{chat_history}\nQuery: {user_message}"
|
||||
if chat_history
|
||||
else f"Query: {user_message}"
|
||||
)
|
||||
|
||||
return {
|
||||
"model": task_model_id,
|
||||
|
|
@ -1327,7 +1331,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
continue
|
||||
|
||||
auth_type = mcp_server_connection.get("auth_type", "")
|
||||
|
||||
headers = {}
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = (
|
||||
|
|
@ -1363,6 +1366,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
log.error(f"Error getting OAuth token: {e}")
|
||||
oauth_token = None
|
||||
|
||||
connection_headers = mcp_server_connection.get("headers", None)
|
||||
if connection_headers:
|
||||
for key, value in connection_headers.items():
|
||||
headers[key] = value
|
||||
|
||||
mcp_clients[server_id] = MCPClient()
|
||||
await mcp_clients[server_id].connect(
|
||||
url=mcp_server_connection.get("url", ""),
|
||||
|
|
|
|||
|
|
@ -155,7 +155,9 @@ async def get_tools(
|
|||
auth_type = tool_server_connection.get("auth_type", "bearer")
|
||||
|
||||
cookies = {}
|
||||
headers = {}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = (
|
||||
|
|
@ -177,7 +179,10 @@ async def get_tools(
|
|||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
)
|
||||
|
||||
headers["Content-Type"] = "application/json"
|
||||
connection_headers = tool_server_connection.get("headers", None)
|
||||
if connection_headers:
|
||||
for key, value in connection_headers.items():
|
||||
headers[key] = value
|
||||
|
||||
def make_tool_function(
|
||||
function_name, tool_server_data, headers
|
||||
|
|
@ -561,20 +566,21 @@ async def get_tool_servers(request: Request):
|
|||
return tool_servers
|
||||
|
||||
|
||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
headers = {
|
||||
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
|
||||
_headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
if headers:
|
||||
_headers.update(headers)
|
||||
|
||||
error = None
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
||||
url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_body = await response.json()
|
||||
|
|
@ -644,7 +650,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
|
|||
openapi_path = server.get("path", "openapi.json")
|
||||
spec_url = get_tool_server_url(server_url, openapi_path)
|
||||
# Fetch from URL
|
||||
task = get_tool_server_data(token, spec_url)
|
||||
task = get_tool_server_data(
|
||||
spec_url,
|
||||
{"Authorization": f"Bearer {token}"} if token else None,
|
||||
)
|
||||
elif spec_type == "json" and server.get("spec", ""):
|
||||
# Use provided JSON spec
|
||||
spec_json = None
|
||||
|
|
|
|||
|
|
@ -426,7 +426,7 @@
|
|||
<div class="flex-1">
|
||||
<Tooltip
|
||||
content={$i18n.t(
|
||||
'Enter additional headers in JSON format (e.g. {{\'{{"X-Custom-Header": "value"}}\'}})'
|
||||
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
|
||||
)}
|
||||
>
|
||||
<Textarea
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
import AccessControl from './workspace/common/AccessControl.svelte';
|
||||
import Spinner from '$lib/components/common/Spinner.svelte';
|
||||
import XMark from '$lib/components/icons/XMark.svelte';
|
||||
import Textarea from './common/Textarea.svelte';
|
||||
|
||||
export let onSubmit: Function = () => {};
|
||||
export let onDelete: Function = () => {};
|
||||
|
|
@ -44,6 +45,7 @@
|
|||
|
||||
let auth_type = 'bearer';
|
||||
let key = '';
|
||||
let headers = '';
|
||||
|
||||
let accessControl = {};
|
||||
|
||||
|
|
@ -110,6 +112,20 @@
|
|||
}
|
||||
}
|
||||
|
||||
if (headers) {
|
||||
try {
|
||||
let _headers = JSON.parse(headers);
|
||||
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
|
||||
_headers = null;
|
||||
throw new Error('Headers must be a valid JSON object');
|
||||
}
|
||||
headers = JSON.stringify(_headers, null, 2);
|
||||
} catch (error) {
|
||||
toast.error($i18n.t('Headers must be a valid JSON object'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (direct) {
|
||||
const res = await getToolServerData(
|
||||
auth_type === 'bearer' ? key : localStorage.token,
|
||||
|
|
@ -128,6 +144,7 @@
|
|||
path,
|
||||
type,
|
||||
auth_type,
|
||||
headers: headers ? JSON.parse(headers) : undefined,
|
||||
key,
|
||||
config: {
|
||||
enable: enable,
|
||||
|
|
@ -177,6 +194,7 @@
|
|||
if (data.path) path = data.path;
|
||||
|
||||
if (data.auth_type) auth_type = data.auth_type;
|
||||
if (data.headers) headers = JSON.stringify(data.headers, null, 2);
|
||||
if (data.key) key = data.key;
|
||||
|
||||
if (data.info) {
|
||||
|
|
@ -210,6 +228,7 @@
|
|||
path,
|
||||
|
||||
auth_type,
|
||||
headers: headers ? JSON.parse(headers) : undefined,
|
||||
key,
|
||||
|
||||
info: {
|
||||
|
|
@ -256,6 +275,19 @@
|
|||
}
|
||||
}
|
||||
|
||||
if (headers) {
|
||||
try {
|
||||
const _headers = JSON.parse(headers);
|
||||
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
|
||||
throw new Error('Headers must be a valid JSON object');
|
||||
}
|
||||
headers = JSON.stringify(_headers, null, 2);
|
||||
} catch (error) {
|
||||
toast.error($i18n.t('Headers must be a valid JSON object'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const connection = {
|
||||
type,
|
||||
url,
|
||||
|
|
@ -265,9 +297,12 @@
|
|||
path,
|
||||
|
||||
auth_type,
|
||||
headers: headers ? JSON.parse(headers) : undefined,
|
||||
|
||||
key,
|
||||
config: {
|
||||
enable: enable,
|
||||
|
||||
access_control: accessControl
|
||||
},
|
||||
info: {
|
||||
|
|
@ -313,6 +348,8 @@
|
|||
path = connection?.path ?? 'openapi.json';
|
||||
|
||||
auth_type = connection?.auth_type ?? 'bearer';
|
||||
headers = connection?.headers ? JSON.stringify(connection.headers, null, 2) : '';
|
||||
|
||||
key = connection?.key ?? '';
|
||||
|
||||
id = connection.info?.id ?? '';
|
||||
|
|
@ -657,6 +694,33 @@
|
|||
</div>
|
||||
|
||||
{#if !direct}
|
||||
<div class="flex gap-2 mt-2">
|
||||
<div class="flex flex-col w-full">
|
||||
<label
|
||||
for="headers-input"
|
||||
class={`mb-0.5 text-xs text-gray-500
|
||||
${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : ''}`}
|
||||
>{$i18n.t('Headers')}</label
|
||||
>
|
||||
|
||||
<div class="flex-1">
|
||||
<Tooltip
|
||||
content={$i18n.t(
|
||||
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
|
||||
)}
|
||||
>
|
||||
<Textarea
|
||||
className="w-full text-sm outline-hidden"
|
||||
bind:value={headers}
|
||||
placeholder={$i18n.t('Enter additional headers in JSON format')}
|
||||
required={false}
|
||||
minSize={30}
|
||||
/>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
|
||||
|
||||
<div class="flex gap-2">
|
||||
|
|
|
|||
Loading…
Reference in a new issue