enh: custom headers for external tool servers

This commit is contained in:
Timothy Jaeryang Baek 2025-11-12 23:39:27 -05:00
parent 0bf686396d
commit da42850eff
5 changed files with 108 additions and 13 deletions

View file

@ -144,6 +144,7 @@ class ToolServerConnection(BaseModel):
path: str path: str
type: Optional[str] = "openapi" # openapi, mcp type: Optional[str] = "openapi" # openapi, mcp
auth_type: Optional[str] auth_type: Optional[str]
headers: Optional[dict]
key: Optional[str] key: Optional[str]
config: Optional[dict] config: Optional[dict]
@ -282,10 +283,14 @@ async def verify_tool_servers_config(
token = oauth_token.get("access_token", "") token = oauth_token.get("access_token", "")
except Exception as e: except Exception as e:
pass pass
if token: if token:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
if form_data.headers:
if headers is None:
headers = {}
headers.update(form_data.headers)
await client.connect(form_data.url, headers=headers) await client.connect(form_data.url, headers=headers)
specs = await client.list_tool_specs() specs = await client.list_tool_specs()
return { return {
@ -303,6 +308,7 @@ async def verify_tool_servers_config(
await client.disconnect() await client.disconnect()
else: # openapi else: # openapi
token = None token = None
headers = None
if form_data.auth_type == "bearer": if form_data.auth_type == "bearer":
token = form_data.key token = form_data.key
elif form_data.auth_type == "session": elif form_data.auth_type == "session":
@ -323,8 +329,16 @@ async def verify_tool_servers_config(
except Exception as e: except Exception as e:
pass pass
if token:
headers = {"Authorization": f"Bearer {token}"}
if form_data.headers:
if headers is None:
headers = {}
headers.update(form_data.headers)
url = get_tool_server_url(form_data.url, form_data.path) url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url) return await get_tool_server_data(url, headers=headers)
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:

View file

@ -312,7 +312,11 @@ async def chat_completion_tools_handler(
for message in recent_messages for message in recent_messages
) )
prompt = f"History:\n{chat_history}\nQuery: {user_message}" if chat_history else f"Query: {user_message}" prompt = (
f"History:\n{chat_history}\nQuery: {user_message}"
if chat_history
else f"Query: {user_message}"
)
return { return {
"model": task_model_id, "model": task_model_id,
@ -1327,7 +1331,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
continue continue
auth_type = mcp_server_connection.get("auth_type", "") auth_type = mcp_server_connection.get("auth_type", "")
headers = {} headers = {}
if auth_type == "bearer": if auth_type == "bearer":
headers["Authorization"] = ( headers["Authorization"] = (
@ -1363,6 +1366,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.error(f"Error getting OAuth token: {e}") log.error(f"Error getting OAuth token: {e}")
oauth_token = None oauth_token = None
connection_headers = mcp_server_connection.get("headers", None)
if connection_headers:
for key, value in connection_headers.items():
headers[key] = value
mcp_clients[server_id] = MCPClient() mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect( await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""), url=mcp_server_connection.get("url", ""),

View file

@ -155,7 +155,9 @@ async def get_tools(
auth_type = tool_server_connection.get("auth_type", "bearer") auth_type = tool_server_connection.get("auth_type", "bearer")
cookies = {} cookies = {}
headers = {} headers = {
"Content-Type": "application/json",
}
if auth_type == "bearer": if auth_type == "bearer":
headers["Authorization"] = ( headers["Authorization"] = (
@ -177,7 +179,10 @@ async def get_tools(
f"Bearer {oauth_token.get('access_token', '')}" f"Bearer {oauth_token.get('access_token', '')}"
) )
headers["Content-Type"] = "application/json" connection_headers = tool_server_connection.get("headers", None)
if connection_headers:
for key, value in connection_headers.items():
headers[key] = value
def make_tool_function( def make_tool_function(
function_name, tool_server_data, headers function_name, tool_server_data, headers
@ -561,20 +566,21 @@ async def get_tool_servers(request: Request):
return tool_servers return tool_servers
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
headers = { _headers = {
"Accept": "application/json", "Accept": "application/json",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
if token:
headers["Authorization"] = f"Bearer {token}" if headers:
_headers.update(headers)
error = None error = None
try: try:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get( async with session.get(
url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
) as response: ) as response:
if response.status != 200: if response.status != 200:
error_body = await response.json() error_body = await response.json()
@ -644,7 +650,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
openapi_path = server.get("path", "openapi.json") openapi_path = server.get("path", "openapi.json")
spec_url = get_tool_server_url(server_url, openapi_path) spec_url = get_tool_server_url(server_url, openapi_path)
# Fetch from URL # Fetch from URL
task = get_tool_server_data(token, spec_url) task = get_tool_server_data(
spec_url,
{"Authorization": f"Bearer {token}"} if token else None,
)
elif spec_type == "json" and server.get("spec", ""): elif spec_type == "json" and server.get("spec", ""):
# Use provided JSON spec # Use provided JSON spec
spec_json = None spec_json = None

View file

@ -426,7 +426,7 @@
<div class="flex-1"> <div class="flex-1">
<Tooltip <Tooltip
content={$i18n.t( content={$i18n.t(
'Enter additional headers in JSON format (e.g. {{\'{{"X-Custom-Header": "value"}}\'}})' 'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
)} )}
> >
<Textarea <Textarea

View file

@ -22,6 +22,7 @@
import AccessControl from './workspace/common/AccessControl.svelte'; import AccessControl from './workspace/common/AccessControl.svelte';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import XMark from '$lib/components/icons/XMark.svelte'; import XMark from '$lib/components/icons/XMark.svelte';
import Textarea from './common/Textarea.svelte';
export let onSubmit: Function = () => {}; export let onSubmit: Function = () => {};
export let onDelete: Function = () => {}; export let onDelete: Function = () => {};
@ -44,6 +45,7 @@
let auth_type = 'bearer'; let auth_type = 'bearer';
let key = ''; let key = '';
let headers = '';
let accessControl = {}; let accessControl = {};
@ -110,6 +112,20 @@
} }
} }
if (headers) {
try {
let _headers = JSON.parse(headers);
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
_headers = null;
throw new Error('Headers must be a valid JSON object');
}
headers = JSON.stringify(_headers, null, 2);
} catch (error) {
toast.error($i18n.t('Headers must be a valid JSON object'));
return;
}
}
if (direct) { if (direct) {
const res = await getToolServerData( const res = await getToolServerData(
auth_type === 'bearer' ? key : localStorage.token, auth_type === 'bearer' ? key : localStorage.token,
@ -128,6 +144,7 @@
path, path,
type, type,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
config: { config: {
enable: enable, enable: enable,
@ -177,6 +194,7 @@
if (data.path) path = data.path; if (data.path) path = data.path;
if (data.auth_type) auth_type = data.auth_type; if (data.auth_type) auth_type = data.auth_type;
if (data.headers) headers = JSON.stringify(data.headers, null, 2);
if (data.key) key = data.key; if (data.key) key = data.key;
if (data.info) { if (data.info) {
@ -210,6 +228,7 @@
path, path,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
info: { info: {
@ -256,6 +275,19 @@
} }
} }
if (headers) {
try {
const _headers = JSON.parse(headers);
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
throw new Error('Headers must be a valid JSON object');
}
headers = JSON.stringify(_headers, null, 2);
} catch (error) {
toast.error($i18n.t('Headers must be a valid JSON object'));
return;
}
}
const connection = { const connection = {
type, type,
url, url,
@ -265,9 +297,12 @@
path, path,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
config: { config: {
enable: enable, enable: enable,
access_control: accessControl access_control: accessControl
}, },
info: { info: {
@ -313,6 +348,8 @@
path = connection?.path ?? 'openapi.json'; path = connection?.path ?? 'openapi.json';
auth_type = connection?.auth_type ?? 'bearer'; auth_type = connection?.auth_type ?? 'bearer';
headers = connection?.headers ? JSON.stringify(connection.headers, null, 2) : '';
key = connection?.key ?? ''; key = connection?.key ?? '';
id = connection.info?.id ?? ''; id = connection.info?.id ?? '';
@ -657,6 +694,33 @@
</div> </div>
{#if !direct} {#if !direct}
<div class="flex gap-2 mt-2">
<div class="flex flex-col w-full">
<label
for="headers-input"
class={`mb-0.5 text-xs text-gray-500
${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : ''}`}
>{$i18n.t('Headers')}</label
>
<div class="flex-1">
<Tooltip
content={$i18n.t(
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
)}
>
<Textarea
className="w-full text-sm outline-hidden"
bind:value={headers}
placeholder={$i18n.t('Enter additional headers in JSON format')}
required={false}
minSize={30}
/>
</Tooltip>
</div>
</div>
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex gap-2"> <div class="flex gap-2">