Add OAuth support for upstream MCP servers (#381)

Co-authored-by: samanhappy <samanhappy@gmail.com>
This commit is contained in:
Copilot
2025-10-26 16:09:34 +08:00
committed by GitHub
parent 7dbd6c386e
commit 26b26a5fb1
30 changed files with 3780 additions and 412 deletions

View File

@@ -10,7 +10,10 @@ import {
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ServerInfo, ServerConfig, Tool } from '../types/index.js';
import { loadSettings, expandEnvVars, replaceEnvVars, getNameSeparator } from '../config/index.js';
import config from '../config/index.js';
@@ -21,6 +24,8 @@ import { OpenAPIClient } from '../clients/openapi.js';
import { RequestContextService } from './requestContextService.js';
import { getDataService } from './services.js';
import { getServerDao, ServerConfigWithName } from '../dao/index.js';
import { initializeAllOAuthClients } from './oauthService.js';
import { createOAuthProvider } from './mcpOAuthProvider.js';
const servers: { [sessionId: string]: Server } = {};
@@ -59,6 +64,10 @@ const setupKeepAlive = (serverInfo: ServerInfo, serverConfig: ServerConfig): voi
};
export const initUpstreamServers = async (): Promise<void> => {
// Initialize OAuth clients for servers with dynamic registration
await initializeAllOAuthClients();
// Register all tools from upstream servers
await registerAllTools(true);
};
@@ -155,28 +164,48 @@ export const cleanupAllServers = (): void => {
};
// Helper function to create transport based on server configuration
const createTransportFromConfig = (name: string, conf: ServerConfig): any => {
export const createTransportFromConfig = async (name: string, conf: ServerConfig): Promise<any> => {
let transport;
if (conf.type === 'streamable-http') {
const options: any = {};
if (conf.headers && Object.keys(conf.headers).length > 0) {
const options: StreamableHTTPClientTransportOptions = {};
const headers = conf.headers ? replaceEnvVars(conf.headers) : {};
if (Object.keys(headers).length > 0) {
options.requestInit = {
headers: replaceEnvVars(conf.headers),
headers,
};
}
// Create OAuth provider if configured - SDK will handle authentication automatically
const authProvider = await createOAuthProvider(name, conf);
if (authProvider) {
options.authProvider = authProvider;
console.log(`OAuth provider configured for server: ${name}`);
}
transport = new StreamableHTTPClientTransport(new URL(conf.url || ''), options);
} else if (conf.url) {
// SSE transport
const options: any = {};
if (conf.headers && Object.keys(conf.headers).length > 0) {
const headers = conf.headers ? replaceEnvVars(conf.headers) : {};
if (Object.keys(headers).length > 0) {
options.eventSourceInit = {
headers: replaceEnvVars(conf.headers),
headers,
};
options.requestInit = {
headers: replaceEnvVars(conf.headers),
headers,
};
}
// Create OAuth provider if configured - SDK will handle authentication automatically
const authProvider = await createOAuthProvider(name, conf);
if (authProvider) {
options.authProvider = authProvider;
console.log(`OAuth provider configured for server: ${name}`);
}
transport = new SSEClientTransport(new URL(conf.url), options);
} else if (conf.command && conf.args) {
// Stdio transport
@@ -269,7 +298,7 @@ const callToolWithReconnect = async (
}
// Recreate transport using helper function
const newTransport = createTransportFromConfig(serverInfo.name, server);
const newTransport = await createTransportFromConfig(serverInfo.name, server);
// Create new client
const client = new Client(
@@ -345,59 +374,143 @@ export const initializeClientsFromSettings = async (
): Promise<ServerInfo[]> => {
const allServers: ServerConfigWithName[] = await serverDao.findAll();
const existingServerInfos = serverInfos;
serverInfos = [];
const nextServerInfos: ServerInfo[] = [];
for (const conf of allServers) {
const { name } = conf;
// Skip disabled servers
if (conf.enabled === false) {
console.log(`Skipping disabled server: ${name}`);
serverInfos.push({
name,
owner: conf.owner,
status: 'disconnected',
error: null,
tools: [],
prompts: [],
createTime: Date.now(),
enabled: false,
});
continue;
}
// Check if server is already connected
const existingServer = existingServerInfos.find(
(s) => s.name === name && s.status === 'connected',
);
if (existingServer && (!serverName || serverName !== name)) {
serverInfos.push({
...existingServer,
enabled: conf.enabled === undefined ? true : conf.enabled,
});
console.log(`Server '${name}' is already connected.`);
continue;
}
let transport;
let openApiClient;
if (conf.type === 'openapi') {
// Handle OpenAPI type servers
if (!conf.openapi?.url && !conf.openapi?.schema) {
console.warn(
`Skipping OpenAPI server '${name}': missing OpenAPI specification URL or schema`,
);
serverInfos.push({
try {
for (const conf of allServers) {
const { name } = conf;
// Skip disabled servers
if (conf.enabled === false) {
console.log(`Skipping disabled server: ${name}`);
nextServerInfos.push({
name,
owner: conf.owner,
status: 'disconnected',
error: 'Missing OpenAPI specification URL or schema',
error: null,
tools: [],
prompts: [],
createTime: Date.now(),
enabled: false,
});
continue;
}
// Check if server is already connected
const existingServer = existingServerInfos.find(
(s) => s.name === name && s.status === 'connected',
);
if (existingServer && (!serverName || serverName !== name)) {
nextServerInfos.push({
...existingServer,
enabled: conf.enabled === undefined ? true : conf.enabled,
});
console.log(`Server '${name}' is already connected.`);
continue;
}
let transport;
let openApiClient;
if (conf.type === 'openapi') {
// Handle OpenAPI type servers
if (!conf.openapi?.url && !conf.openapi?.schema) {
console.warn(
`Skipping OpenAPI server '${name}': missing OpenAPI specification URL or schema`,
);
nextServerInfos.push({
name,
owner: conf.owner,
status: 'disconnected',
error: 'Missing OpenAPI specification URL or schema',
tools: [],
prompts: [],
createTime: Date.now(),
});
continue;
}
// Create server info first and keep reference to it
const serverInfo: ServerInfo = {
name,
owner: conf.owner,
status: 'connecting',
error: null,
tools: [],
prompts: [],
createTime: Date.now(),
enabled: conf.enabled === undefined ? true : conf.enabled,
config: conf, // Store reference to original config for OpenAPI passthrough headers
};
nextServerInfos.push(serverInfo);
try {
// Create OpenAPI client instance
openApiClient = new OpenAPIClient(conf);
console.log(`Initializing OpenAPI server: ${name}...`);
// Perform async initialization
await openApiClient.initialize();
// Convert OpenAPI tools to MCP tool format
const openApiTools = openApiClient.getTools();
const mcpTools: Tool[] = openApiTools.map((tool) => ({
name: `${name}${getNameSeparator()}${tool.name}`,
description: tool.description,
inputSchema: cleanInputSchema(tool.inputSchema),
}));
// Update server info with successful initialization
serverInfo.status = 'connected';
serverInfo.tools = mcpTools;
serverInfo.openApiClient = openApiClient;
console.log(
`Successfully initialized OpenAPI server: ${name} with ${mcpTools.length} tools`,
);
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, mcpTools);
continue;
} catch (error) {
console.error(`Failed to initialize OpenAPI server ${name}:`, error);
// Update the already pushed server info with error status
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to initialize OpenAPI server: ${error}`;
continue;
}
} else {
transport = await createTransportFromConfig(name, conf);
}
const client = new Client(
{
name: `mcp-client-${name}`,
version: '1.0.0',
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
},
},
);
const initRequestOptions = isInit
? {
timeout: Number(config.initTimeout) || 60000,
}
: undefined;
// Get request options from server configuration, with fallbacks
const serverRequestOptions = conf.options || {};
const requestOptions = {
timeout: serverRequestOptions.timeout || 60000,
resetTimeoutOnProgress: serverRequestOptions.resetTimeoutOnProgress || false,
maxTotalTimeout: serverRequestOptions.maxTotalTimeout,
};
// Create server info first and keep reference to it
const serverInfo: ServerInfo = {
name,
@@ -406,169 +519,122 @@ export const initializeClientsFromSettings = async (
error: null,
tools: [],
prompts: [],
client,
transport,
options: requestOptions,
createTime: Date.now(),
enabled: conf.enabled === undefined ? true : conf.enabled,
config: conf, // Store reference to original config for OpenAPI passthrough headers
config: conf, // Store reference to original config
};
serverInfos.push(serverInfo);
try {
// Create OpenAPI client instance
openApiClient = new OpenAPIClient(conf);
console.log(`Initializing OpenAPI server: ${name}...`);
// Perform async initialization
await openApiClient.initialize();
// Convert OpenAPI tools to MCP tool format
const openApiTools = openApiClient.getTools();
const mcpTools: Tool[] = openApiTools.map((tool) => ({
name: `${name}${getNameSeparator()}${tool.name}`,
description: tool.description,
inputSchema: cleanInputSchema(tool.inputSchema),
}));
// Update server info with successful initialization
serverInfo.status = 'connected';
serverInfo.tools = mcpTools;
serverInfo.openApiClient = openApiClient;
console.log(
`Successfully initialized OpenAPI server: ${name} with ${mcpTools.length} tools`,
);
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, mcpTools);
continue;
} catch (error) {
console.error(`Failed to initialize OpenAPI server ${name}:`, error);
// Update the already pushed server info with error status
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to initialize OpenAPI server: ${error}`;
continue;
const pendingAuth = conf.oauth?.pendingAuthorization;
if (pendingAuth) {
serverInfo.status = 'oauth_required';
serverInfo.error = null;
serverInfo.oauth = {
authorizationUrl: pendingAuth.authorizationUrl,
state: pendingAuth.state,
codeVerifier: pendingAuth.codeVerifier,
};
}
} else {
transport = createTransportFromConfig(name, conf);
nextServerInfos.push(serverInfo);
client
.connect(transport, initRequestOptions || requestOptions)
.then(() => {
console.log(`Successfully connected client for server: ${name}`);
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
let dataError: Error | null = null;
if (capabilities?.tools) {
client
.listTools({}, initRequestOptions || requestOptions)
.then((tools) => {
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
serverInfo.tools = tools.tools.map((tool) => ({
name: `${name}${getNameSeparator()}${tool.name}`,
description: tool.description || '',
inputSchema: cleanInputSchema(tool.inputSchema || {}),
}));
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
})
.catch((error) => {
console.error(
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
if (capabilities?.prompts) {
client
.listPrompts({}, initRequestOptions || requestOptions)
.then((prompts) => {
console.log(
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
);
serverInfo.prompts = prompts.prompts.map((prompt) => ({
name: `${name}${getNameSeparator()}${prompt.name}`,
title: prompt.title,
description: prompt.description,
arguments: prompt.arguments,
}));
})
.catch((error) => {
console.error(
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
if (!dataError) {
serverInfo.status = 'connected';
serverInfo.error = null;
// Set up keep-alive ping for SSE connections
setupKeepAlive(serverInfo, conf);
} else {
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to list data: ${dataError} `;
}
})
.catch(async (error) => {
// Check if this is an OAuth authorization error
const isOAuthError =
error?.message?.includes('OAuth authorization required') ||
error?.message?.includes('Authorization required');
if (isOAuthError) {
// OAuth provider should have already set the status to 'oauth_required'
// and stored the authorization URL in serverInfo.oauth
console.log(
`OAuth authorization required for server ${name}. Status should be set to 'oauth_required'.`,
);
// Make sure status is set correctly
if (serverInfo.status !== 'oauth_required') {
serverInfo.status = 'oauth_required';
}
serverInfo.error = null;
} else {
console.error(
`Failed to connect client for server ${name} by error: ${error} with stack: ${error.stack}`,
);
// Other connection errors
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to connect: ${error.stack} `;
}
});
console.log(`Initialized client for server: ${name}`);
}
const client = new Client(
{
name: `mcp-client-${name}`,
version: '1.0.0',
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
},
},
);
const initRequestOptions = isInit
? {
timeout: Number(config.initTimeout) || 60000,
}
: undefined;
// Get request options from server configuration, with fallbacks
const serverRequestOptions = conf.options || {};
const requestOptions = {
timeout: serverRequestOptions.timeout || 60000,
resetTimeoutOnProgress: serverRequestOptions.resetTimeoutOnProgress || false,
maxTotalTimeout: serverRequestOptions.maxTotalTimeout,
};
// Create server info first and keep reference to it
const serverInfo: ServerInfo = {
name,
owner: conf.owner,
status: 'connecting',
error: null,
tools: [],
prompts: [],
client,
transport,
options: requestOptions,
createTime: Date.now(),
config: conf, // Store reference to original config
};
serverInfos.push(serverInfo);
client
.connect(transport, initRequestOptions || requestOptions)
.then(() => {
console.log(`Successfully connected client for server: ${name}`);
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
let dataError: Error | null = null;
if (capabilities?.tools) {
client
.listTools({}, initRequestOptions || requestOptions)
.then((tools) => {
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
serverInfo.tools = tools.tools.map((tool) => ({
name: `${name}${getNameSeparator()}${tool.name}`,
description: tool.description || '',
inputSchema: cleanInputSchema(tool.inputSchema || {}),
}));
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
})
.catch((error) => {
console.error(
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
if (capabilities?.prompts) {
client
.listPrompts({}, initRequestOptions || requestOptions)
.then((prompts) => {
console.log(
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
);
serverInfo.prompts = prompts.prompts.map((prompt) => ({
name: `${name}${getNameSeparator()}${prompt.name}`,
title: prompt.title,
description: prompt.description,
arguments: prompt.arguments,
}));
})
.catch((error) => {
console.error(
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
if (!dataError) {
serverInfo.status = 'connected';
serverInfo.error = null;
// Set up keep-alive ping for SSE connections
setupKeepAlive(serverInfo, conf);
} else {
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to list data: ${dataError} `;
}
})
.catch((error) => {
console.error(
`Failed to connect client for server ${name} by error: ${error} with stack: ${error.stack}`,
);
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to connect: ${error.stack} `;
});
console.log(`Initialized client for server: ${name}`);
} catch (error) {
// Restore previous state if initialization fails to avoid exposing an empty server list
serverInfos = existingServerInfos;
throw error;
}
serverInfos = nextServerInfos;
return serverInfos;
};
@@ -584,39 +650,48 @@ export const getServersInfo = async (): Promise<Omit<ServerInfo, 'client' | 'tra
const filterServerInfos: ServerInfo[] = dataService.filterData
? dataService.filterData(serverInfos)
: serverInfos;
const infos = filterServerInfos.map(({ name, status, tools, prompts, createTime, error }) => {
const serverConfig = allServers.find((server) => server.name === name);
const enabled = serverConfig ? serverConfig.enabled !== false : true;
const infos = filterServerInfos.map(
({ name, status, tools, prompts, createTime, error, oauth }) => {
const serverConfig = allServers.find((server) => server.name === name);
const enabled = serverConfig ? serverConfig.enabled !== false : true;
// Add enabled status and custom description to each tool
const toolsWithEnabled = tools.map((tool) => {
const toolConfig = serverConfig?.tools?.[tool.name];
return {
...tool,
description: toolConfig?.description || tool.description, // Use custom description if available
enabled: toolConfig?.enabled !== false, // Default to true if not explicitly disabled
};
});
const promptsWithEnabled = prompts.map((prompt) => {
const promptConfig = serverConfig?.prompts?.[prompt.name];
return {
...prompt,
description: promptConfig?.description || prompt.description, // Use custom description if available
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
};
});
// Add enabled status and custom description to each tool
const toolsWithEnabled = tools.map((tool) => {
const toolConfig = serverConfig?.tools?.[tool.name];
return {
...tool,
description: toolConfig?.description || tool.description, // Use custom description if available
enabled: toolConfig?.enabled !== false, // Default to true if not explicitly disabled
name,
status,
error,
tools: toolsWithEnabled,
prompts: promptsWithEnabled,
createTime,
enabled,
oauth: oauth
? {
authorizationUrl: oauth.authorizationUrl,
state: oauth.state,
// Don't expose codeVerifier to frontend for security
}
: undefined,
};
});
const promptsWithEnabled = prompts.map((prompt) => {
const promptConfig = serverConfig?.prompts?.[prompt.name];
return {
...prompt,
description: promptConfig?.description || prompt.description, // Use custom description if available
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
};
});
return {
name,
status,
error,
tools: toolsWithEnabled,
prompts: promptsWithEnabled,
createTime,
enabled,
};
});
},
);
infos.sort((a, b) => {
if (a.enabled === b.enabled) return 0;
return a.enabled ? -1 : 1;
@@ -629,6 +704,51 @@ export const getServerByName = (name: string): ServerInfo | undefined => {
return serverInfos.find((serverInfo) => serverInfo.name === name);
};
// Get server by OAuth state parameter
export const getServerByOAuthState = (state: string): ServerInfo | undefined => {
return serverInfos.find((serverInfo) => serverInfo.oauth?.state === state);
};
/**
* Reconnect a server after OAuth authorization or configuration change
* This will close the existing connection and reinitialize the server
*/
export const reconnectServer = async (serverName: string): Promise<void> => {
console.log(`Reconnecting server: ${serverName}`);
const serverInfo = getServerByName(serverName);
if (!serverInfo) {
throw new Error(`Server not found: ${serverName}`);
}
// Close existing connection if any
if (serverInfo.client) {
try {
serverInfo.client.close();
} catch (error) {
console.warn(`Error closing client for server ${serverName}:`, error);
}
}
if (serverInfo.transport) {
try {
serverInfo.transport.close();
} catch (error) {
console.warn(`Error closing transport for server ${serverName}:`, error);
}
}
if (serverInfo.keepAliveIntervalId) {
clearInterval(serverInfo.keepAliveIntervalId);
serverInfo.keepAliveIntervalId = undefined;
}
// Reinitialize the server
await initializeClientsFromSettings(false, serverName);
console.log(`Successfully reconnected server: ${serverName}`);
};
// Filter tools by server configuration
const filterToolsByConfig = async (serverName: string, tools: Tool[]): Promise<Tool[]> => {
const serverConfig = await serverDao.findById(serverName);