mirror of
https://github.com/samanhappy/mcphub.git
synced 2025-12-31 20:00:00 -05:00
Add OAuth support for upstream MCP servers (#381)
Co-authored-by: samanhappy <samanhappy@gmail.com>
This commit is contained in:
588
src/services/mcpOAuthProvider.ts
Normal file
588
src/services/mcpOAuthProvider.ts
Normal file
@@ -0,0 +1,588 @@
|
||||
/**
|
||||
* MCP OAuth Provider Implementation
|
||||
*
|
||||
* Implements OAuthClientProvider interface from @modelcontextprotocol/sdk/client/auth.js
|
||||
* to handle OAuth 2.0 authentication for upstream MCP servers using the SDK's built-in
|
||||
* OAuth support.
|
||||
*
|
||||
* This provider integrates with our existing OAuth infrastructure:
|
||||
* - Dynamic client registration (RFC7591)
|
||||
* - Token storage and refresh
|
||||
* - Authorization flow handling
|
||||
*/
|
||||
|
||||
import { randomBytes } from 'node:crypto';
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { ServerConfig } from '../types/index.js';
|
||||
import { loadSettings } from '../config/index.js';
|
||||
import {
|
||||
initializeOAuthForServer,
|
||||
getRegisteredClient,
|
||||
removeRegisteredClient,
|
||||
fetchScopesFromServer,
|
||||
} from './oauthClientRegistration.js';
|
||||
import {
|
||||
clearOAuthData,
|
||||
loadServerConfig,
|
||||
mutateOAuthSettings,
|
||||
persistClientCredentials,
|
||||
persistTokens,
|
||||
updatePendingAuthorization,
|
||||
ServerConfigWithOAuth,
|
||||
} from './oauthSettingsStore.js';
|
||||
|
||||
// Import getServerByName to access ServerInfo
|
||||
import { getServerByName } from './mcpService.js';
|
||||
|
||||
/**
|
||||
* MCPHub OAuth Provider for server-side OAuth flows
|
||||
*
|
||||
* This provider handles OAuth authentication for upstream MCP servers.
|
||||
* Unlike browser-based providers, this runs in a Node.js server environment,
|
||||
* so the authorization flow requires external handling (e.g., via web UI).
|
||||
*/
|
||||
export class MCPHubOAuthProvider implements OAuthClientProvider {
|
||||
private serverName: string;
|
||||
private serverConfig: ServerConfig;
|
||||
private _codeVerifier?: string;
|
||||
private _currentState?: string;
|
||||
|
||||
constructor(serverName: string, serverConfig: ServerConfig) {
|
||||
this.serverName = serverName;
|
||||
this.serverConfig = serverConfig;
|
||||
}
|
||||
|
||||
private getSystemInstallBaseUrl(): string | undefined {
|
||||
const settings = loadSettings();
|
||||
return settings.systemConfig?.install?.baseUrl;
|
||||
}
|
||||
|
||||
private sanitizeRedirectUri(input?: string): string | null {
|
||||
if (!input) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(input);
|
||||
url.searchParams.delete('server');
|
||||
const params = url.searchParams.toString();
|
||||
url.search = params ? `?${params}` : '';
|
||||
return url.toString();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private buildRedirectUriFromBase(baseUrl?: string): string | null {
|
||||
if (!baseUrl) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const normalizedBase = trimmed.endsWith('/') ? trimmed : `${trimmed}/`;
|
||||
const redirect = new URL('oauth/callback', normalizedBase);
|
||||
return this.sanitizeRedirectUri(redirect.toString());
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get redirect URL for OAuth callback
|
||||
*/
|
||||
get redirectUrl(): string {
|
||||
const dynamicConfig = this.serverConfig.oauth?.dynamicRegistration;
|
||||
const metadata = dynamicConfig?.metadata || {};
|
||||
const fallback = 'http://localhost:3000/oauth/callback';
|
||||
const systemConfigured = this.buildRedirectUriFromBase(this.getSystemInstallBaseUrl());
|
||||
const metadataConfigured = this.sanitizeRedirectUri(metadata.redirect_uris?.[0]);
|
||||
|
||||
return systemConfigured ?? metadataConfigured ?? fallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get client metadata for dynamic registration or static configuration
|
||||
*/
|
||||
get clientMetadata(): OAuthClientMetadata {
|
||||
const dynamicConfig = this.serverConfig.oauth?.dynamicRegistration;
|
||||
const metadata = dynamicConfig?.metadata || {};
|
||||
|
||||
// Use redirectUrl getter to ensure consistent callback URL
|
||||
const redirectUri = this.redirectUrl;
|
||||
const systemConfigured = this.buildRedirectUriFromBase(this.getSystemInstallBaseUrl());
|
||||
const metadataRedirects =
|
||||
metadata.redirect_uris && metadata.redirect_uris.length > 0
|
||||
? metadata.redirect_uris
|
||||
.map((uri) => this.sanitizeRedirectUri(uri))
|
||||
.filter((uri): uri is string => Boolean(uri))
|
||||
: [];
|
||||
const redirectUris: string[] = [];
|
||||
|
||||
if (systemConfigured) {
|
||||
redirectUris.push(systemConfigured);
|
||||
}
|
||||
|
||||
for (const uri of metadataRedirects) {
|
||||
if (!redirectUris.includes(uri)) {
|
||||
redirectUris.push(uri);
|
||||
}
|
||||
}
|
||||
|
||||
if (!redirectUris.includes(redirectUri)) {
|
||||
redirectUris.push(redirectUri);
|
||||
}
|
||||
|
||||
const tokenEndpointAuthMethod =
|
||||
metadata.token_endpoint_auth_method && metadata.token_endpoint_auth_method !== ''
|
||||
? metadata.token_endpoint_auth_method
|
||||
: this.serverConfig.oauth?.clientSecret
|
||||
? 'client_secret_post'
|
||||
: 'none';
|
||||
|
||||
return {
|
||||
...metadata, // Include any additional custom metadata
|
||||
client_name: metadata.client_name || `MCPHub - ${this.serverName}`,
|
||||
redirect_uris: redirectUris,
|
||||
grant_types: metadata.grant_types || ['authorization_code', 'refresh_token'],
|
||||
response_types: metadata.response_types || ['code'],
|
||||
token_endpoint_auth_method: tokenEndpointAuthMethod,
|
||||
scope: metadata.scope || this.serverConfig.oauth?.scopes?.join(' ') || 'openid',
|
||||
};
|
||||
}
|
||||
|
||||
private async ensureScopesFromServer(): Promise<string[] | undefined> {
|
||||
const serverUrl = this.serverConfig.url;
|
||||
const existingScopes = this.serverConfig.oauth?.scopes;
|
||||
|
||||
if (!serverUrl) {
|
||||
return existingScopes;
|
||||
}
|
||||
|
||||
if (existingScopes && existingScopes.length > 0) {
|
||||
return existingScopes;
|
||||
}
|
||||
|
||||
try {
|
||||
const scopes = await fetchScopesFromServer(serverUrl);
|
||||
if (scopes && scopes.length > 0) {
|
||||
const updatedConfig = await mutateOAuthSettings(this.serverName, ({ oauth }) => {
|
||||
oauth.scopes = scopes;
|
||||
});
|
||||
if (updatedConfig) {
|
||||
this.serverConfig = updatedConfig;
|
||||
}
|
||||
console.log(`Stored auto-detected scopes for ${this.serverName}: ${scopes.join(', ')}`);
|
||||
return scopes;
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
`Failed to auto-detect scopes for ${this.serverName}: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
|
||||
return existingScopes;
|
||||
}
|
||||
|
||||
private generateState(): string {
|
||||
const payload = {
|
||||
server: this.serverName,
|
||||
nonce: randomBytes(16).toString('hex'),
|
||||
};
|
||||
const base64 = Buffer.from(JSON.stringify(payload)).toString('base64');
|
||||
return base64.replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
}
|
||||
|
||||
async state(): Promise<string> {
|
||||
if (!this._currentState) {
|
||||
this._currentState = this.generateState();
|
||||
}
|
||||
return this._currentState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get previously registered client information
|
||||
*/
|
||||
clientInformation(): OAuthClientInformation | undefined {
|
||||
const clientInfo = getRegisteredClient(this.serverName);
|
||||
|
||||
if (!clientInfo) {
|
||||
// Try to use static client configuration from cached serverConfig first
|
||||
let serverConfig = this.serverConfig;
|
||||
|
||||
// If cached config doesn't have clientId, reload from settings
|
||||
if (!serverConfig?.oauth?.clientId) {
|
||||
const storedConfig = loadServerConfig(this.serverName);
|
||||
|
||||
if (storedConfig) {
|
||||
this.serverConfig = storedConfig;
|
||||
serverConfig = storedConfig;
|
||||
}
|
||||
}
|
||||
|
||||
// Try to use static client configuration from serverConfig
|
||||
if (serverConfig?.oauth?.clientId) {
|
||||
return {
|
||||
client_id: serverConfig.oauth.clientId,
|
||||
client_secret: serverConfig.oauth.clientSecret,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
client_id: clientInfo.clientId,
|
||||
client_secret: clientInfo.clientSecret,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Save registered client information
|
||||
* Called by SDK after successful dynamic registration
|
||||
*/
|
||||
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
||||
console.log(`Saving OAuth client information for server: ${this.serverName}`);
|
||||
|
||||
const scopeString = info.scope?.trim();
|
||||
const scopes =
|
||||
scopeString && scopeString.length > 0
|
||||
? scopeString.split(/\s+/).filter((value) => value.length > 0)
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const updatedConfig = await persistClientCredentials(this.serverName, {
|
||||
clientId: info.client_id,
|
||||
clientSecret: info.client_secret,
|
||||
scopes,
|
||||
});
|
||||
|
||||
if (updatedConfig) {
|
||||
this.serverConfig = updatedConfig;
|
||||
}
|
||||
|
||||
if (!scopes || scopes.length === 0) {
|
||||
await this.ensureScopesFromServer();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to persist OAuth client credentials for server ${this.serverName}:`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stored OAuth tokens
|
||||
*/
|
||||
tokens(): OAuthTokens | undefined {
|
||||
// Use cached config first, but reload if needed
|
||||
let serverConfig = this.serverConfig;
|
||||
|
||||
// If cached config doesn't have tokens, try reloading
|
||||
if (!serverConfig?.oauth?.accessToken) {
|
||||
const storedConfig = loadServerConfig(this.serverName);
|
||||
if (storedConfig) {
|
||||
this.serverConfig = storedConfig;
|
||||
serverConfig = storedConfig;
|
||||
}
|
||||
}
|
||||
|
||||
if (!serverConfig?.oauth?.accessToken) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
access_token: serverConfig.oauth.accessToken,
|
||||
token_type: 'Bearer',
|
||||
refresh_token: serverConfig.oauth.refreshToken,
|
||||
// Note: expires_in is not typically stored, only the token itself
|
||||
// The SDK will handle token refresh when needed
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Save OAuth tokens
|
||||
* Called by SDK after successful token exchange or refresh
|
||||
*/
|
||||
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||
const currentOAuth = this.serverConfig.oauth;
|
||||
const accessTokenChanged = currentOAuth?.accessToken !== tokens.access_token;
|
||||
const refreshTokenProvided = tokens.refresh_token !== undefined;
|
||||
const refreshTokenChanged =
|
||||
refreshTokenProvided && currentOAuth?.refreshToken !== tokens.refresh_token;
|
||||
const hadPending = Boolean(currentOAuth?.pendingAuthorization);
|
||||
|
||||
if (!accessTokenChanged && !refreshTokenChanged && !hadPending) {
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Saving OAuth tokens for server: ${this.serverName}`);
|
||||
|
||||
const updatedConfig = await persistTokens(this.serverName, {
|
||||
accessToken: tokens.access_token,
|
||||
refreshToken: refreshTokenProvided ? tokens.refresh_token ?? null : undefined,
|
||||
clearPendingAuthorization: hadPending,
|
||||
});
|
||||
|
||||
if (updatedConfig) {
|
||||
this.serverConfig = updatedConfig;
|
||||
}
|
||||
|
||||
this._codeVerifier = undefined;
|
||||
this._currentState = undefined;
|
||||
|
||||
const serverInfo = getServerByName(this.serverName);
|
||||
if (serverInfo) {
|
||||
serverInfo.oauth = undefined;
|
||||
}
|
||||
|
||||
console.log(`Saved OAuth tokens for server: ${this.serverName}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Redirect to authorization URL
|
||||
* In a server environment, we can't directly redirect the user
|
||||
* Instead, we store the URL in ServerInfo for the frontend to access
|
||||
*/
|
||||
async redirectToAuthorization(url: URL): Promise<void> {
|
||||
console.log('='.repeat(80));
|
||||
console.log(`OAuth Authorization Required for server: ${this.serverName}`);
|
||||
console.log(`Authorization URL: ${url.toString()}`);
|
||||
console.log('='.repeat(80));
|
||||
let state = url.searchParams.get('state') || undefined;
|
||||
|
||||
if (!state) {
|
||||
state = await this.state();
|
||||
url.searchParams.set('state', state);
|
||||
} else {
|
||||
this._currentState = state;
|
||||
}
|
||||
|
||||
const authorizationUrl = url.toString();
|
||||
|
||||
try {
|
||||
const pendingUpdate: Partial<NonNullable<ServerConfig['oauth']>['pendingAuthorization']> = {
|
||||
authorizationUrl,
|
||||
state,
|
||||
};
|
||||
|
||||
if (this._codeVerifier) {
|
||||
pendingUpdate.codeVerifier = this._codeVerifier;
|
||||
}
|
||||
|
||||
const updatedConfig = await updatePendingAuthorization(this.serverName, pendingUpdate);
|
||||
if (updatedConfig) {
|
||||
this.serverConfig = updatedConfig;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to persist pending OAuth authorization state for ${this.serverName}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
// Store the authorization URL in ServerInfo for the frontend to access
|
||||
const serverInfo = getServerByName(this.serverName);
|
||||
if (serverInfo) {
|
||||
serverInfo.status = 'oauth_required';
|
||||
serverInfo.oauth = {
|
||||
authorizationUrl,
|
||||
state,
|
||||
codeVerifier: this._codeVerifier,
|
||||
};
|
||||
console.log(`Stored OAuth authorization URL in ServerInfo for server: ${this.serverName}`);
|
||||
} else {
|
||||
console.warn(`ServerInfo not found for ${this.serverName}, cannot store authorization URL`);
|
||||
}
|
||||
|
||||
// Throw error to indicate authorization is needed
|
||||
// The error will be caught in the connection flow and handled appropriately
|
||||
throw new Error(
|
||||
`OAuth authorization required for server ${this.serverName}. Please complete OAuth flow via web UI.`,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save PKCE code verifier for later use in token exchange
|
||||
*/
|
||||
async saveCodeVerifier(verifier: string): Promise<void> {
|
||||
this._codeVerifier = verifier;
|
||||
try {
|
||||
const updatedConfig = await updatePendingAuthorization(this.serverName, { codeVerifier: verifier });
|
||||
if (updatedConfig) {
|
||||
this.serverConfig = updatedConfig;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Failed to persist OAuth code verifier for ${this.serverName}:`, error);
|
||||
}
|
||||
console.log(`Saved code verifier for server: ${this.serverName}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve PKCE code verifier for token exchange
|
||||
*/
|
||||
async codeVerifier(): Promise<string> {
|
||||
if (this._codeVerifier) {
|
||||
return this._codeVerifier;
|
||||
}
|
||||
|
||||
const storedConfig = loadServerConfig(this.serverName);
|
||||
const storedVerifier = storedConfig?.oauth?.pendingAuthorization?.codeVerifier;
|
||||
|
||||
if (storedVerifier) {
|
||||
this.serverConfig = storedConfig || this.serverConfig;
|
||||
this._codeVerifier = storedVerifier;
|
||||
return storedVerifier;
|
||||
}
|
||||
|
||||
throw new Error(`No code verifier stored for server: ${this.serverName}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invalidate cached OAuth credentials when the SDK detects they are no longer valid.
|
||||
* This keeps stored configuration in sync and forces a fresh authorization flow.
|
||||
*/
|
||||
async invalidateCredentials(scope: 'all' | 'client' | 'tokens' | 'verifier'): Promise<void> {
|
||||
const storedConfig = loadServerConfig(this.serverName);
|
||||
|
||||
if (!storedConfig?.oauth) {
|
||||
if (scope === 'verifier' || scope === 'all') {
|
||||
this._codeVerifier = undefined;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let currentConfig = storedConfig as ServerConfigWithOAuth;
|
||||
const assignUpdatedConfig = (updated?: ServerConfigWithOAuth) => {
|
||||
if (updated) {
|
||||
currentConfig = updated;
|
||||
this.serverConfig = updated;
|
||||
} else {
|
||||
this.serverConfig = currentConfig;
|
||||
}
|
||||
};
|
||||
|
||||
assignUpdatedConfig(currentConfig);
|
||||
let changed = false;
|
||||
|
||||
if (scope === 'tokens' || scope === 'all') {
|
||||
if (currentConfig.oauth.accessToken || currentConfig.oauth.refreshToken) {
|
||||
const updated = await clearOAuthData(this.serverName, 'tokens');
|
||||
assignUpdatedConfig(updated);
|
||||
changed = true;
|
||||
console.warn(`Cleared OAuth tokens for server: ${this.serverName}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (scope === 'client' || scope === 'all') {
|
||||
const supportsDynamicClient = currentConfig.oauth.dynamicRegistration?.enabled === true;
|
||||
|
||||
if (supportsDynamicClient && (currentConfig.oauth.clientId || currentConfig.oauth.clientSecret)) {
|
||||
removeRegisteredClient(this.serverName);
|
||||
const updated = await clearOAuthData(this.serverName, 'client');
|
||||
assignUpdatedConfig(updated);
|
||||
changed = true;
|
||||
console.warn(`Cleared OAuth client registration for server: ${this.serverName}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (scope === 'verifier' || scope === 'all') {
|
||||
this._codeVerifier = undefined;
|
||||
this._currentState = undefined;
|
||||
if (currentConfig.oauth.pendingAuthorization) {
|
||||
const updated = await clearOAuthData(this.serverName, 'verifier');
|
||||
assignUpdatedConfig(updated);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
this._currentState = undefined;
|
||||
const serverInfo = getServerByName(this.serverName);
|
||||
if (serverInfo) {
|
||||
serverInfo.status = 'oauth_required';
|
||||
serverInfo.oauth = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const prepopulateScopesIfMissing = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
): Promise<void> => {
|
||||
if (!serverConfig.oauth || serverConfig.oauth.scopes?.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!serverConfig.url) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const scopes = await fetchScopesFromServer(serverConfig.url);
|
||||
if (scopes && scopes.length > 0) {
|
||||
const updatedConfig = await mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
oauth.scopes = scopes;
|
||||
});
|
||||
|
||||
if (!serverConfig.oauth) {
|
||||
serverConfig.oauth = {};
|
||||
}
|
||||
serverConfig.oauth.scopes = scopes;
|
||||
|
||||
if (updatedConfig) {
|
||||
console.log(`Stored auto-detected scopes for ${serverName}: ${scopes.join(', ')}`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
`Failed to auto-detect scopes for ${serverName} during provider initialization: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Create an OAuth provider for a server if OAuth is configured
|
||||
*
|
||||
* @param serverName - Name of the server
|
||||
* @param serverConfig - Server configuration
|
||||
* @returns OAuthClientProvider instance or undefined if OAuth not configured
|
||||
*/
|
||||
export const createOAuthProvider = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
): Promise<OAuthClientProvider | undefined> => {
|
||||
// Ensure scopes are pre-populated if dynamic registration already ran previously
|
||||
await prepopulateScopesIfMissing(serverName, serverConfig);
|
||||
|
||||
// Initialize OAuth for the server (performs registration if needed)
|
||||
// This ensures the client is registered before the SDK tries to use it
|
||||
try {
|
||||
await initializeOAuthForServer(serverName, serverConfig);
|
||||
} catch (error) {
|
||||
console.warn(`Failed to initialize OAuth for server ${serverName}:`, error);
|
||||
// Continue anyway - the SDK might be able to handle it
|
||||
}
|
||||
|
||||
// Create and return the provider
|
||||
const provider = new MCPHubOAuthProvider(serverName, serverConfig);
|
||||
|
||||
console.log(`Created OAuth provider for server: ${serverName}`);
|
||||
return provider;
|
||||
};
|
||||
@@ -10,7 +10,10 @@ import {
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { ServerInfo, ServerConfig, Tool } from '../types/index.js';
|
||||
import { loadSettings, expandEnvVars, replaceEnvVars, getNameSeparator } from '../config/index.js';
|
||||
import config from '../config/index.js';
|
||||
@@ -21,6 +24,8 @@ import { OpenAPIClient } from '../clients/openapi.js';
|
||||
import { RequestContextService } from './requestContextService.js';
|
||||
import { getDataService } from './services.js';
|
||||
import { getServerDao, ServerConfigWithName } from '../dao/index.js';
|
||||
import { initializeAllOAuthClients } from './oauthService.js';
|
||||
import { createOAuthProvider } from './mcpOAuthProvider.js';
|
||||
|
||||
const servers: { [sessionId: string]: Server } = {};
|
||||
|
||||
@@ -59,6 +64,10 @@ const setupKeepAlive = (serverInfo: ServerInfo, serverConfig: ServerConfig): voi
|
||||
};
|
||||
|
||||
export const initUpstreamServers = async (): Promise<void> => {
|
||||
// Initialize OAuth clients for servers with dynamic registration
|
||||
await initializeAllOAuthClients();
|
||||
|
||||
// Register all tools from upstream servers
|
||||
await registerAllTools(true);
|
||||
};
|
||||
|
||||
@@ -155,28 +164,48 @@ export const cleanupAllServers = (): void => {
|
||||
};
|
||||
|
||||
// Helper function to create transport based on server configuration
|
||||
const createTransportFromConfig = (name: string, conf: ServerConfig): any => {
|
||||
export const createTransportFromConfig = async (name: string, conf: ServerConfig): Promise<any> => {
|
||||
let transport;
|
||||
|
||||
if (conf.type === 'streamable-http') {
|
||||
const options: any = {};
|
||||
if (conf.headers && Object.keys(conf.headers).length > 0) {
|
||||
const options: StreamableHTTPClientTransportOptions = {};
|
||||
const headers = conf.headers ? replaceEnvVars(conf.headers) : {};
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
options.requestInit = {
|
||||
headers: replaceEnvVars(conf.headers),
|
||||
headers,
|
||||
};
|
||||
}
|
||||
|
||||
// Create OAuth provider if configured - SDK will handle authentication automatically
|
||||
const authProvider = await createOAuthProvider(name, conf);
|
||||
if (authProvider) {
|
||||
options.authProvider = authProvider;
|
||||
console.log(`OAuth provider configured for server: ${name}`);
|
||||
}
|
||||
|
||||
transport = new StreamableHTTPClientTransport(new URL(conf.url || ''), options);
|
||||
} else if (conf.url) {
|
||||
// SSE transport
|
||||
const options: any = {};
|
||||
if (conf.headers && Object.keys(conf.headers).length > 0) {
|
||||
const headers = conf.headers ? replaceEnvVars(conf.headers) : {};
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
options.eventSourceInit = {
|
||||
headers: replaceEnvVars(conf.headers),
|
||||
headers,
|
||||
};
|
||||
options.requestInit = {
|
||||
headers: replaceEnvVars(conf.headers),
|
||||
headers,
|
||||
};
|
||||
}
|
||||
|
||||
// Create OAuth provider if configured - SDK will handle authentication automatically
|
||||
const authProvider = await createOAuthProvider(name, conf);
|
||||
if (authProvider) {
|
||||
options.authProvider = authProvider;
|
||||
console.log(`OAuth provider configured for server: ${name}`);
|
||||
}
|
||||
|
||||
transport = new SSEClientTransport(new URL(conf.url), options);
|
||||
} else if (conf.command && conf.args) {
|
||||
// Stdio transport
|
||||
@@ -269,7 +298,7 @@ const callToolWithReconnect = async (
|
||||
}
|
||||
|
||||
// Recreate transport using helper function
|
||||
const newTransport = createTransportFromConfig(serverInfo.name, server);
|
||||
const newTransport = await createTransportFromConfig(serverInfo.name, server);
|
||||
|
||||
// Create new client
|
||||
const client = new Client(
|
||||
@@ -345,59 +374,143 @@ export const initializeClientsFromSettings = async (
|
||||
): Promise<ServerInfo[]> => {
|
||||
const allServers: ServerConfigWithName[] = await serverDao.findAll();
|
||||
const existingServerInfos = serverInfos;
|
||||
serverInfos = [];
|
||||
const nextServerInfos: ServerInfo[] = [];
|
||||
|
||||
for (const conf of allServers) {
|
||||
const { name } = conf;
|
||||
// Skip disabled servers
|
||||
if (conf.enabled === false) {
|
||||
console.log(`Skipping disabled server: ${name}`);
|
||||
serverInfos.push({
|
||||
name,
|
||||
owner: conf.owner,
|
||||
status: 'disconnected',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
enabled: false,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if server is already connected
|
||||
const existingServer = existingServerInfos.find(
|
||||
(s) => s.name === name && s.status === 'connected',
|
||||
);
|
||||
if (existingServer && (!serverName || serverName !== name)) {
|
||||
serverInfos.push({
|
||||
...existingServer,
|
||||
enabled: conf.enabled === undefined ? true : conf.enabled,
|
||||
});
|
||||
console.log(`Server '${name}' is already connected.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
let transport;
|
||||
let openApiClient;
|
||||
if (conf.type === 'openapi') {
|
||||
// Handle OpenAPI type servers
|
||||
if (!conf.openapi?.url && !conf.openapi?.schema) {
|
||||
console.warn(
|
||||
`Skipping OpenAPI server '${name}': missing OpenAPI specification URL or schema`,
|
||||
);
|
||||
serverInfos.push({
|
||||
try {
|
||||
for (const conf of allServers) {
|
||||
const { name } = conf;
|
||||
// Skip disabled servers
|
||||
if (conf.enabled === false) {
|
||||
console.log(`Skipping disabled server: ${name}`);
|
||||
nextServerInfos.push({
|
||||
name,
|
||||
owner: conf.owner,
|
||||
status: 'disconnected',
|
||||
error: 'Missing OpenAPI specification URL or schema',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
enabled: false,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if server is already connected
|
||||
const existingServer = existingServerInfos.find(
|
||||
(s) => s.name === name && s.status === 'connected',
|
||||
);
|
||||
if (existingServer && (!serverName || serverName !== name)) {
|
||||
nextServerInfos.push({
|
||||
...existingServer,
|
||||
enabled: conf.enabled === undefined ? true : conf.enabled,
|
||||
});
|
||||
console.log(`Server '${name}' is already connected.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
let transport;
|
||||
let openApiClient;
|
||||
if (conf.type === 'openapi') {
|
||||
// Handle OpenAPI type servers
|
||||
if (!conf.openapi?.url && !conf.openapi?.schema) {
|
||||
console.warn(
|
||||
`Skipping OpenAPI server '${name}': missing OpenAPI specification URL or schema`,
|
||||
);
|
||||
nextServerInfos.push({
|
||||
name,
|
||||
owner: conf.owner,
|
||||
status: 'disconnected',
|
||||
error: 'Missing OpenAPI specification URL or schema',
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create server info first and keep reference to it
|
||||
const serverInfo: ServerInfo = {
|
||||
name,
|
||||
owner: conf.owner,
|
||||
status: 'connecting',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
enabled: conf.enabled === undefined ? true : conf.enabled,
|
||||
config: conf, // Store reference to original config for OpenAPI passthrough headers
|
||||
};
|
||||
nextServerInfos.push(serverInfo);
|
||||
|
||||
try {
|
||||
// Create OpenAPI client instance
|
||||
openApiClient = new OpenAPIClient(conf);
|
||||
|
||||
console.log(`Initializing OpenAPI server: ${name}...`);
|
||||
|
||||
// Perform async initialization
|
||||
await openApiClient.initialize();
|
||||
|
||||
// Convert OpenAPI tools to MCP tool format
|
||||
const openApiTools = openApiClient.getTools();
|
||||
const mcpTools: Tool[] = openApiTools.map((tool) => ({
|
||||
name: `${name}${getNameSeparator()}${tool.name}`,
|
||||
description: tool.description,
|
||||
inputSchema: cleanInputSchema(tool.inputSchema),
|
||||
}));
|
||||
|
||||
// Update server info with successful initialization
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.tools = mcpTools;
|
||||
serverInfo.openApiClient = openApiClient;
|
||||
|
||||
console.log(
|
||||
`Successfully initialized OpenAPI server: ${name} with ${mcpTools.length} tools`,
|
||||
);
|
||||
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, mcpTools);
|
||||
continue;
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize OpenAPI server ${name}:`, error);
|
||||
|
||||
// Update the already pushed server info with error status
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to initialize OpenAPI server: ${error}`;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
transport = await createTransportFromConfig(name, conf);
|
||||
}
|
||||
|
||||
const client = new Client(
|
||||
{
|
||||
name: `mcp-client-${name}`,
|
||||
version: '1.0.0',
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
prompts: {},
|
||||
resources: {},
|
||||
tools: {},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const initRequestOptions = isInit
|
||||
? {
|
||||
timeout: Number(config.initTimeout) || 60000,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
// Get request options from server configuration, with fallbacks
|
||||
const serverRequestOptions = conf.options || {};
|
||||
const requestOptions = {
|
||||
timeout: serverRequestOptions.timeout || 60000,
|
||||
resetTimeoutOnProgress: serverRequestOptions.resetTimeoutOnProgress || false,
|
||||
maxTotalTimeout: serverRequestOptions.maxTotalTimeout,
|
||||
};
|
||||
|
||||
// Create server info first and keep reference to it
|
||||
const serverInfo: ServerInfo = {
|
||||
name,
|
||||
@@ -406,169 +519,122 @@ export const initializeClientsFromSettings = async (
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
client,
|
||||
transport,
|
||||
options: requestOptions,
|
||||
createTime: Date.now(),
|
||||
enabled: conf.enabled === undefined ? true : conf.enabled,
|
||||
config: conf, // Store reference to original config for OpenAPI passthrough headers
|
||||
config: conf, // Store reference to original config
|
||||
};
|
||||
serverInfos.push(serverInfo);
|
||||
|
||||
try {
|
||||
// Create OpenAPI client instance
|
||||
openApiClient = new OpenAPIClient(conf);
|
||||
|
||||
console.log(`Initializing OpenAPI server: ${name}...`);
|
||||
|
||||
// Perform async initialization
|
||||
await openApiClient.initialize();
|
||||
|
||||
// Convert OpenAPI tools to MCP tool format
|
||||
const openApiTools = openApiClient.getTools();
|
||||
const mcpTools: Tool[] = openApiTools.map((tool) => ({
|
||||
name: `${name}${getNameSeparator()}${tool.name}`,
|
||||
description: tool.description,
|
||||
inputSchema: cleanInputSchema(tool.inputSchema),
|
||||
}));
|
||||
|
||||
// Update server info with successful initialization
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.tools = mcpTools;
|
||||
serverInfo.openApiClient = openApiClient;
|
||||
|
||||
console.log(
|
||||
`Successfully initialized OpenAPI server: ${name} with ${mcpTools.length} tools`,
|
||||
);
|
||||
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, mcpTools);
|
||||
continue;
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize OpenAPI server ${name}:`, error);
|
||||
|
||||
// Update the already pushed server info with error status
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to initialize OpenAPI server: ${error}`;
|
||||
continue;
|
||||
const pendingAuth = conf.oauth?.pendingAuthorization;
|
||||
if (pendingAuth) {
|
||||
serverInfo.status = 'oauth_required';
|
||||
serverInfo.error = null;
|
||||
serverInfo.oauth = {
|
||||
authorizationUrl: pendingAuth.authorizationUrl,
|
||||
state: pendingAuth.state,
|
||||
codeVerifier: pendingAuth.codeVerifier,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
transport = createTransportFromConfig(name, conf);
|
||||
nextServerInfos.push(serverInfo);
|
||||
|
||||
client
|
||||
.connect(transport, initRequestOptions || requestOptions)
|
||||
.then(() => {
|
||||
console.log(`Successfully connected client for server: ${name}`);
|
||||
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
|
||||
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
|
||||
|
||||
let dataError: Error | null = null;
|
||||
if (capabilities?.tools) {
|
||||
client
|
||||
.listTools({}, initRequestOptions || requestOptions)
|
||||
.then((tools) => {
|
||||
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
|
||||
serverInfo.tools = tools.tools.map((tool) => ({
|
||||
name: `${name}${getNameSeparator()}${tool.name}`,
|
||||
description: tool.description || '',
|
||||
inputSchema: cleanInputSchema(tool.inputSchema || {}),
|
||||
}));
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
if (capabilities?.prompts) {
|
||||
client
|
||||
.listPrompts({}, initRequestOptions || requestOptions)
|
||||
.then((prompts) => {
|
||||
console.log(
|
||||
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
|
||||
);
|
||||
serverInfo.prompts = prompts.prompts.map((prompt) => ({
|
||||
name: `${name}${getNameSeparator()}${prompt.name}`,
|
||||
title: prompt.title,
|
||||
description: prompt.description,
|
||||
arguments: prompt.arguments,
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
if (!dataError) {
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.error = null;
|
||||
|
||||
// Set up keep-alive ping for SSE connections
|
||||
setupKeepAlive(serverInfo, conf);
|
||||
} else {
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to list data: ${dataError} `;
|
||||
}
|
||||
})
|
||||
.catch(async (error) => {
|
||||
// Check if this is an OAuth authorization error
|
||||
const isOAuthError =
|
||||
error?.message?.includes('OAuth authorization required') ||
|
||||
error?.message?.includes('Authorization required');
|
||||
|
||||
if (isOAuthError) {
|
||||
// OAuth provider should have already set the status to 'oauth_required'
|
||||
// and stored the authorization URL in serverInfo.oauth
|
||||
console.log(
|
||||
`OAuth authorization required for server ${name}. Status should be set to 'oauth_required'.`,
|
||||
);
|
||||
// Make sure status is set correctly
|
||||
if (serverInfo.status !== 'oauth_required') {
|
||||
serverInfo.status = 'oauth_required';
|
||||
}
|
||||
serverInfo.error = null;
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to connect client for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
// Other connection errors
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to connect: ${error.stack} `;
|
||||
}
|
||||
});
|
||||
console.log(`Initialized client for server: ${name}`);
|
||||
}
|
||||
|
||||
const client = new Client(
|
||||
{
|
||||
name: `mcp-client-${name}`,
|
||||
version: '1.0.0',
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
prompts: {},
|
||||
resources: {},
|
||||
tools: {},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const initRequestOptions = isInit
|
||||
? {
|
||||
timeout: Number(config.initTimeout) || 60000,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
// Get request options from server configuration, with fallbacks
|
||||
const serverRequestOptions = conf.options || {};
|
||||
const requestOptions = {
|
||||
timeout: serverRequestOptions.timeout || 60000,
|
||||
resetTimeoutOnProgress: serverRequestOptions.resetTimeoutOnProgress || false,
|
||||
maxTotalTimeout: serverRequestOptions.maxTotalTimeout,
|
||||
};
|
||||
|
||||
// Create server info first and keep reference to it
|
||||
const serverInfo: ServerInfo = {
|
||||
name,
|
||||
owner: conf.owner,
|
||||
status: 'connecting',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
client,
|
||||
transport,
|
||||
options: requestOptions,
|
||||
createTime: Date.now(),
|
||||
config: conf, // Store reference to original config
|
||||
};
|
||||
serverInfos.push(serverInfo);
|
||||
|
||||
client
|
||||
.connect(transport, initRequestOptions || requestOptions)
|
||||
.then(() => {
|
||||
console.log(`Successfully connected client for server: ${name}`);
|
||||
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
|
||||
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
|
||||
|
||||
let dataError: Error | null = null;
|
||||
if (capabilities?.tools) {
|
||||
client
|
||||
.listTools({}, initRequestOptions || requestOptions)
|
||||
.then((tools) => {
|
||||
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
|
||||
serverInfo.tools = tools.tools.map((tool) => ({
|
||||
name: `${name}${getNameSeparator()}${tool.name}`,
|
||||
description: tool.description || '',
|
||||
inputSchema: cleanInputSchema(tool.inputSchema || {}),
|
||||
}));
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
if (capabilities?.prompts) {
|
||||
client
|
||||
.listPrompts({}, initRequestOptions || requestOptions)
|
||||
.then((prompts) => {
|
||||
console.log(
|
||||
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
|
||||
);
|
||||
serverInfo.prompts = prompts.prompts.map((prompt) => ({
|
||||
name: `${name}${getNameSeparator()}${prompt.name}`,
|
||||
title: prompt.title,
|
||||
description: prompt.description,
|
||||
arguments: prompt.arguments,
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
if (!dataError) {
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.error = null;
|
||||
|
||||
// Set up keep-alive ping for SSE connections
|
||||
setupKeepAlive(serverInfo, conf);
|
||||
} else {
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to list data: ${dataError} `;
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to connect client for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to connect: ${error.stack} `;
|
||||
});
|
||||
console.log(`Initialized client for server: ${name}`);
|
||||
} catch (error) {
|
||||
// Restore previous state if initialization fails to avoid exposing an empty server list
|
||||
serverInfos = existingServerInfos;
|
||||
throw error;
|
||||
}
|
||||
|
||||
serverInfos = nextServerInfos;
|
||||
return serverInfos;
|
||||
};
|
||||
|
||||
@@ -584,39 +650,48 @@ export const getServersInfo = async (): Promise<Omit<ServerInfo, 'client' | 'tra
|
||||
const filterServerInfos: ServerInfo[] = dataService.filterData
|
||||
? dataService.filterData(serverInfos)
|
||||
: serverInfos;
|
||||
const infos = filterServerInfos.map(({ name, status, tools, prompts, createTime, error }) => {
|
||||
const serverConfig = allServers.find((server) => server.name === name);
|
||||
const enabled = serverConfig ? serverConfig.enabled !== false : true;
|
||||
const infos = filterServerInfos.map(
|
||||
({ name, status, tools, prompts, createTime, error, oauth }) => {
|
||||
const serverConfig = allServers.find((server) => server.name === name);
|
||||
const enabled = serverConfig ? serverConfig.enabled !== false : true;
|
||||
|
||||
// Add enabled status and custom description to each tool
|
||||
const toolsWithEnabled = tools.map((tool) => {
|
||||
const toolConfig = serverConfig?.tools?.[tool.name];
|
||||
return {
|
||||
...tool,
|
||||
description: toolConfig?.description || tool.description, // Use custom description if available
|
||||
enabled: toolConfig?.enabled !== false, // Default to true if not explicitly disabled
|
||||
};
|
||||
});
|
||||
|
||||
const promptsWithEnabled = prompts.map((prompt) => {
|
||||
const promptConfig = serverConfig?.prompts?.[prompt.name];
|
||||
return {
|
||||
...prompt,
|
||||
description: promptConfig?.description || prompt.description, // Use custom description if available
|
||||
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
|
||||
};
|
||||
});
|
||||
|
||||
// Add enabled status and custom description to each tool
|
||||
const toolsWithEnabled = tools.map((tool) => {
|
||||
const toolConfig = serverConfig?.tools?.[tool.name];
|
||||
return {
|
||||
...tool,
|
||||
description: toolConfig?.description || tool.description, // Use custom description if available
|
||||
enabled: toolConfig?.enabled !== false, // Default to true if not explicitly disabled
|
||||
name,
|
||||
status,
|
||||
error,
|
||||
tools: toolsWithEnabled,
|
||||
prompts: promptsWithEnabled,
|
||||
createTime,
|
||||
enabled,
|
||||
oauth: oauth
|
||||
? {
|
||||
authorizationUrl: oauth.authorizationUrl,
|
||||
state: oauth.state,
|
||||
// Don't expose codeVerifier to frontend for security
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
});
|
||||
|
||||
const promptsWithEnabled = prompts.map((prompt) => {
|
||||
const promptConfig = serverConfig?.prompts?.[prompt.name];
|
||||
return {
|
||||
...prompt,
|
||||
description: promptConfig?.description || prompt.description, // Use custom description if available
|
||||
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
name,
|
||||
status,
|
||||
error,
|
||||
tools: toolsWithEnabled,
|
||||
prompts: promptsWithEnabled,
|
||||
createTime,
|
||||
enabled,
|
||||
};
|
||||
});
|
||||
},
|
||||
);
|
||||
infos.sort((a, b) => {
|
||||
if (a.enabled === b.enabled) return 0;
|
||||
return a.enabled ? -1 : 1;
|
||||
@@ -629,6 +704,51 @@ export const getServerByName = (name: string): ServerInfo | undefined => {
|
||||
return serverInfos.find((serverInfo) => serverInfo.name === name);
|
||||
};
|
||||
|
||||
// Get server by OAuth state parameter
|
||||
export const getServerByOAuthState = (state: string): ServerInfo | undefined => {
|
||||
return serverInfos.find((serverInfo) => serverInfo.oauth?.state === state);
|
||||
};
|
||||
|
||||
/**
|
||||
* Reconnect a server after OAuth authorization or configuration change
|
||||
* This will close the existing connection and reinitialize the server
|
||||
*/
|
||||
export const reconnectServer = async (serverName: string): Promise<void> => {
|
||||
console.log(`Reconnecting server: ${serverName}`);
|
||||
|
||||
const serverInfo = getServerByName(serverName);
|
||||
if (!serverInfo) {
|
||||
throw new Error(`Server not found: ${serverName}`);
|
||||
}
|
||||
|
||||
// Close existing connection if any
|
||||
if (serverInfo.client) {
|
||||
try {
|
||||
serverInfo.client.close();
|
||||
} catch (error) {
|
||||
console.warn(`Error closing client for server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (serverInfo.transport) {
|
||||
try {
|
||||
serverInfo.transport.close();
|
||||
} catch (error) {
|
||||
console.warn(`Error closing transport for server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (serverInfo.keepAliveIntervalId) {
|
||||
clearInterval(serverInfo.keepAliveIntervalId);
|
||||
serverInfo.keepAliveIntervalId = undefined;
|
||||
}
|
||||
|
||||
// Reinitialize the server
|
||||
await initializeClientsFromSettings(false, serverName);
|
||||
|
||||
console.log(`Successfully reconnected server: ${serverName}`);
|
||||
};
|
||||
|
||||
// Filter tools by server configuration
|
||||
const filterToolsByConfig = async (serverName: string, tools: Tool[]): Promise<Tool[]> => {
|
||||
const serverConfig = await serverDao.findById(serverName);
|
||||
|
||||
584
src/services/oauthClientRegistration.ts
Normal file
584
src/services/oauthClientRegistration.ts
Normal file
@@ -0,0 +1,584 @@
|
||||
/**
|
||||
* OAuth 2.0 Dynamic Client Registration Service
|
||||
*
|
||||
* Implements dynamic client registration for upstream MCP servers based on:
|
||||
* - RFC7591: OAuth 2.0 Dynamic Client Registration Protocol
|
||||
* - RFC8414: OAuth 2.0 Authorization Server Metadata
|
||||
* - MCP Authorization Specification
|
||||
*
|
||||
* Uses the standard openid-client library for OAuth operations.
|
||||
*/
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { ServerConfig } from '../types/index.js';
|
||||
import {
|
||||
mutateOAuthSettings,
|
||||
persistClientCredentials,
|
||||
persistTokens,
|
||||
} from './oauthSettingsStore.js';
|
||||
|
||||
interface RegisteredClientInfo {
|
||||
config: client.Configuration;
|
||||
clientId: string;
|
||||
clientSecret?: string;
|
||||
registrationAccessToken?: string;
|
||||
registrationClientUri?: string;
|
||||
expiresAt?: number;
|
||||
metadata: any;
|
||||
}
|
||||
|
||||
// Cache for registered clients to avoid re-registering on every restart
|
||||
const registeredClients = new Map<string, RegisteredClientInfo>();
|
||||
|
||||
export const removeRegisteredClient = (serverName: string): void => {
|
||||
registeredClients.delete(serverName);
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse WWW-Authenticate header to extract resource server metadata URL
|
||||
* Following RFC9728 Protected Resource Metadata specification
|
||||
*
|
||||
* Example header: WWW-Authenticate: Bearer resource="https://mcp.example.com/.well-known/oauth-protected-resource"
|
||||
*/
|
||||
export const parseWWWAuthenticateHeader = (header: string): string | null => {
|
||||
if (!header || !header.toLowerCase().startsWith('bearer ')) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Extract resource parameter from WWW-Authenticate header
|
||||
const resourceMatch = header.match(/resource="([^"]+)"/i);
|
||||
if (resourceMatch && resourceMatch[1]) {
|
||||
return resourceMatch[1];
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetch protected resource metadata from MCP server
|
||||
* Following RFC9728 section 3
|
||||
*
|
||||
* @param resourceMetadataUrl - URL to fetch resource metadata (from WWW-Authenticate header)
|
||||
* @returns Authorization server URLs and other metadata
|
||||
*/
|
||||
export const fetchProtectedResourceMetadata = async (
|
||||
resourceMetadataUrl: string,
|
||||
): Promise<{
|
||||
authorization_servers: string[];
|
||||
resource?: string;
|
||||
[key: string]: any;
|
||||
}> => {
|
||||
try {
|
||||
console.log(`Fetching protected resource metadata from: ${resourceMetadataUrl}`);
|
||||
|
||||
const response = await fetch(resourceMetadataUrl, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch resource metadata: ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
|
||||
const metadata = await response.json();
|
||||
|
||||
if (!metadata.authorization_servers || !Array.isArray(metadata.authorization_servers)) {
|
||||
throw new Error('Invalid resource metadata: missing authorization_servers field');
|
||||
}
|
||||
|
||||
console.log(`Found ${metadata.authorization_servers.length} authorization server(s)`);
|
||||
return metadata;
|
||||
} catch (error) {
|
||||
console.warn(`Failed to fetch protected resource metadata:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetch scopes from protected resource metadata by trying the well-known URL
|
||||
*
|
||||
* @param serverUrl - The MCP server URL
|
||||
* @returns Array of supported scopes or undefined if not available
|
||||
*/
|
||||
export const fetchScopesFromServer = async (serverUrl: string): Promise<string[] | undefined> => {
|
||||
try {
|
||||
// Construct the well-known protected resource metadata URL
|
||||
// Format: https://example.com/.well-known/oauth-protected-resource/path/to/resource
|
||||
const url = new URL(serverUrl);
|
||||
const resourcePath = url.pathname + url.search;
|
||||
const wellKnownUrl = `${url.origin}/.well-known/oauth-protected-resource${resourcePath}`;
|
||||
|
||||
console.log(`Attempting to fetch scopes from: ${wellKnownUrl}`);
|
||||
|
||||
const metadata = await fetchProtectedResourceMetadata(wellKnownUrl);
|
||||
|
||||
if (metadata.scopes_supported && Array.isArray(metadata.scopes_supported)) {
|
||||
console.log(`Fetched scopes from server: ${metadata.scopes_supported.join(', ')}`);
|
||||
return metadata.scopes_supported as string[];
|
||||
}
|
||||
|
||||
return undefined;
|
||||
} catch (error) {
|
||||
console.log(
|
||||
`Could not fetch scopes from server (this is normal if not using OAuth discovery): ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Auto-detect OAuth configuration from 401 response
|
||||
* Following MCP Authorization Specification for automatic discovery
|
||||
*
|
||||
* @param wwwAuthenticateHeader - The WWW-Authenticate header value from 401 response
|
||||
* @param serverUrl - The MCP server URL that returned 401
|
||||
* @returns Issuer URL and resource URL for OAuth configuration
|
||||
*/
|
||||
export const autoDetectOAuthConfig = async (
|
||||
wwwAuthenticateHeader: string,
|
||||
serverUrl: string,
|
||||
): Promise<{ issuer: string; resource: string; scopes?: string[] } | null> => {
|
||||
try {
|
||||
// Step 1: Parse WWW-Authenticate header to get resource metadata URL
|
||||
const resourceMetadataUrl = parseWWWAuthenticateHeader(wwwAuthenticateHeader);
|
||||
|
||||
if (!resourceMetadataUrl) {
|
||||
console.log('No resource metadata URL found in WWW-Authenticate header');
|
||||
return null;
|
||||
}
|
||||
|
||||
// Step 2: Fetch protected resource metadata
|
||||
const resourceMetadata = await fetchProtectedResourceMetadata(resourceMetadataUrl);
|
||||
|
||||
// Step 3: Select first authorization server (TODO: implement proper selection logic)
|
||||
const issuer = resourceMetadata.authorization_servers[0];
|
||||
|
||||
if (!issuer) {
|
||||
throw new Error('No authorization servers found in resource metadata');
|
||||
}
|
||||
|
||||
// Step 4: Determine resource URL (canonical URI of MCP server)
|
||||
const resource = resourceMetadata.resource || new URL(serverUrl).origin;
|
||||
|
||||
// Step 5: Extract supported scopes from resource metadata
|
||||
const scopes = resourceMetadata.scopes_supported as string[] | undefined;
|
||||
|
||||
console.log(`Auto-detected OAuth configuration:`);
|
||||
console.log(` Issuer: ${issuer}`);
|
||||
console.log(` Resource: ${resource}`);
|
||||
if (scopes && scopes.length > 0) {
|
||||
console.log(` Scopes: ${scopes.join(', ')}`);
|
||||
}
|
||||
|
||||
return { issuer, resource, scopes };
|
||||
} catch (error) {
|
||||
console.error('Failed to auto-detect OAuth configuration:', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Perform OAuth 2.0 issuer discovery to get authorization server metadata
|
||||
*/
|
||||
export const discoverIssuer = async (
|
||||
issuerUrl: string,
|
||||
clientId: string = 'mcphub-temp',
|
||||
clientSecret?: string,
|
||||
): Promise<client.Configuration> => {
|
||||
try {
|
||||
console.log(`Discovering OAuth issuer: ${issuerUrl}`);
|
||||
const server = new URL(issuerUrl);
|
||||
|
||||
const clientAuth = clientSecret ? client.ClientSecretPost(clientSecret) : client.None();
|
||||
|
||||
const config = await client.discovery(server, clientId, undefined, clientAuth);
|
||||
console.log(`Successfully discovered OAuth issuer: ${issuerUrl}`);
|
||||
return config;
|
||||
} catch (error) {
|
||||
console.error(`Failed to discover OAuth issuer ${issuerUrl}:`, error);
|
||||
throw new Error(
|
||||
`OAuth issuer discovery failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Register a new OAuth client dynamically using RFC7591
|
||||
* Can be called with auto-detected configuration from 401 response
|
||||
*/
|
||||
export const registerClient = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
autoDetectedIssuer?: string,
|
||||
autoDetectedScopes?: string[],
|
||||
): Promise<RegisteredClientInfo> => {
|
||||
// Check if we already have a registered client for this server
|
||||
const cached = registeredClients.get(serverName);
|
||||
if (cached && (!cached.expiresAt || cached.expiresAt > Date.now())) {
|
||||
console.log(`Using cached OAuth client for server: ${serverName}`);
|
||||
return cached;
|
||||
}
|
||||
|
||||
const dynamicConfig = serverConfig.oauth?.dynamicRegistration;
|
||||
|
||||
try {
|
||||
let serverUrl: URL;
|
||||
|
||||
// Step 1: Determine the authorization server URL
|
||||
// Priority: autoDetectedIssuer > configured issuer > registration endpoint
|
||||
const issuerUrl = autoDetectedIssuer || dynamicConfig?.issuer;
|
||||
|
||||
if (issuerUrl) {
|
||||
serverUrl = new URL(issuerUrl);
|
||||
} else if (dynamicConfig?.registrationEndpoint) {
|
||||
// Extract server URL from registration endpoint
|
||||
const regUrl = new URL(dynamicConfig.registrationEndpoint);
|
||||
serverUrl = new URL(`${regUrl.protocol}//${regUrl.host}`);
|
||||
} else {
|
||||
throw new Error(
|
||||
`Cannot register OAuth client: no issuer URL available. Either provide 'issuer' in configuration or ensure server returns proper 401 with WWW-Authenticate header.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 2: Prepare client metadata for registration
|
||||
const metadata = dynamicConfig?.metadata || {};
|
||||
|
||||
// Determine scopes: priority is metadata.scope > autoDetectedScopes > configured scopes > 'openid'
|
||||
let scopeValue: string;
|
||||
if (metadata.scope) {
|
||||
scopeValue = metadata.scope;
|
||||
} else if (autoDetectedScopes && autoDetectedScopes.length > 0) {
|
||||
scopeValue = autoDetectedScopes.join(' ');
|
||||
} else if (serverConfig.oauth?.scopes) {
|
||||
scopeValue = serverConfig.oauth.scopes.join(' ');
|
||||
} else {
|
||||
scopeValue = 'openid';
|
||||
}
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_name: metadata.client_name || `MCPHub - ${serverName}`,
|
||||
redirect_uris: metadata.redirect_uris || ['http://localhost:3000/oauth/callback'],
|
||||
grant_types: metadata.grant_types || ['authorization_code', 'refresh_token'],
|
||||
response_types: metadata.response_types || ['code'],
|
||||
token_endpoint_auth_method: metadata.token_endpoint_auth_method || 'client_secret_post',
|
||||
scope: scopeValue,
|
||||
...metadata, // Include any additional custom metadata
|
||||
};
|
||||
|
||||
console.log(`Registering OAuth client for server: ${serverName}`);
|
||||
console.log(`Server URL: ${serverUrl}`);
|
||||
console.log(`Client metadata:`, JSON.stringify(clientMetadata, null, 2));
|
||||
|
||||
// Step 3: Perform dynamic client registration
|
||||
const clientAuth = dynamicConfig?.initialAccessToken
|
||||
? client.ClientSecretPost(dynamicConfig.initialAccessToken)
|
||||
: client.None();
|
||||
|
||||
const config = await client.dynamicClientRegistration(serverUrl, clientMetadata, clientAuth);
|
||||
|
||||
console.log(`Successfully registered OAuth client for server: ${serverName}`);
|
||||
|
||||
// Extract client ID from the configuration
|
||||
const clientId = (config as any).client_id || (config as any).clientId;
|
||||
console.log(`Client ID: ${clientId}`);
|
||||
|
||||
// Step 4: Store registered client information
|
||||
const clientInfo: RegisteredClientInfo = {
|
||||
config,
|
||||
clientId,
|
||||
clientSecret: (config as any).client_secret, // Access client secret if available
|
||||
registrationAccessToken: (config as any).registrationAccessToken,
|
||||
registrationClientUri: (config as any).registrationClientUri,
|
||||
expiresAt: (config as any).client_secret_expires_at
|
||||
? (config as any).client_secret_expires_at * 1000
|
||||
: undefined,
|
||||
metadata: config,
|
||||
};
|
||||
|
||||
// Cache the registered client
|
||||
registeredClients.set(serverName, clientInfo);
|
||||
|
||||
// Persist the client credentials and scopes to configuration
|
||||
const persistedConfig = await persistClientCredentials(serverName, {
|
||||
clientId,
|
||||
clientSecret: clientInfo.clientSecret,
|
||||
scopes: autoDetectedScopes,
|
||||
authorizationEndpoint: clientInfo.config.serverMetadata().authorization_endpoint,
|
||||
tokenEndpoint: clientInfo.config.serverMetadata().token_endpoint,
|
||||
});
|
||||
|
||||
if (persistedConfig) {
|
||||
serverConfig.oauth = {
|
||||
...(serverConfig.oauth || {}),
|
||||
...persistedConfig.oauth,
|
||||
};
|
||||
}
|
||||
|
||||
return clientInfo;
|
||||
} catch (error) {
|
||||
console.error(`Failed to register OAuth client for server ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get authorization URL for user authorization (OAuth 2.0 authorization code flow)
|
||||
*/
|
||||
export const getAuthorizationUrl = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
clientInfo: RegisteredClientInfo,
|
||||
redirectUri: string,
|
||||
state: string,
|
||||
codeVerifier: string,
|
||||
): Promise<string> => {
|
||||
try {
|
||||
// Generate code challenge for PKCE (required by MCP spec)
|
||||
const codeChallenge = await client.calculatePKCECodeChallenge(codeVerifier);
|
||||
|
||||
// Build authorization parameters
|
||||
const params: Record<string, string> = {
|
||||
redirect_uri: redirectUri,
|
||||
state,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
scope: serverConfig.oauth?.scopes?.join(' ') || 'openid',
|
||||
};
|
||||
|
||||
// Add resource parameter for MCP (RFC8707)
|
||||
if (serverConfig.oauth?.resource) {
|
||||
params.resource = serverConfig.oauth.resource;
|
||||
}
|
||||
|
||||
const authUrl = client.buildAuthorizationUrl(clientInfo.config, params);
|
||||
return authUrl.toString();
|
||||
} catch (error) {
|
||||
console.error(`Failed to generate authorization URL for server ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Exchange authorization code for access token
|
||||
*/
|
||||
export const exchangeCodeForToken = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
clientInfo: RegisteredClientInfo,
|
||||
currentUrl: string,
|
||||
codeVerifier: string,
|
||||
): Promise<{ accessToken: string; refreshToken?: string; expiresIn?: number }> => {
|
||||
try {
|
||||
console.log(`Exchanging authorization code for access token for server: ${serverName}`);
|
||||
|
||||
// Prepare token endpoint parameters
|
||||
const tokenParams: Record<string, string> = {
|
||||
code_verifier: codeVerifier,
|
||||
};
|
||||
|
||||
// Add resource parameter for MCP (RFC8707)
|
||||
if (serverConfig.oauth?.resource) {
|
||||
tokenParams.resource = serverConfig.oauth.resource;
|
||||
}
|
||||
|
||||
const tokens = await client.authorizationCodeGrant(
|
||||
clientInfo.config,
|
||||
new URL(currentUrl),
|
||||
{ expectedState: undefined }, // State is already validated
|
||||
tokenParams,
|
||||
);
|
||||
|
||||
console.log(`Successfully obtained access token for server: ${serverName}`);
|
||||
|
||||
await persistTokens(serverName, {
|
||||
accessToken: tokens.access_token,
|
||||
refreshToken: tokens.refresh_token ?? undefined,
|
||||
});
|
||||
|
||||
return {
|
||||
accessToken: tokens.access_token,
|
||||
refreshToken: tokens.refresh_token,
|
||||
expiresIn: tokens.expires_in,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error(`Failed to exchange code for token for server ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Refresh access token using refresh token
|
||||
*/
|
||||
export const refreshAccessToken = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
clientInfo: RegisteredClientInfo,
|
||||
refreshToken: string,
|
||||
): Promise<{ accessToken: string; refreshToken?: string; expiresIn?: number }> => {
|
||||
try {
|
||||
console.log(`Refreshing access token for server: ${serverName}`);
|
||||
|
||||
// Prepare refresh token parameters
|
||||
const params: Record<string, string> = {};
|
||||
|
||||
// Add resource parameter for MCP (RFC8707)
|
||||
if (serverConfig.oauth?.resource) {
|
||||
params.resource = serverConfig.oauth.resource;
|
||||
}
|
||||
|
||||
const tokens = await client.refreshTokenGrant(clientInfo.config, refreshToken, params);
|
||||
|
||||
console.log(`Successfully refreshed access token for server: ${serverName}`);
|
||||
|
||||
await persistTokens(serverName, {
|
||||
accessToken: tokens.access_token,
|
||||
refreshToken: tokens.refresh_token ?? undefined,
|
||||
});
|
||||
|
||||
return {
|
||||
accessToken: tokens.access_token,
|
||||
refreshToken: tokens.refresh_token,
|
||||
expiresIn: tokens.expires_in,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error(`Failed to refresh access token for server ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier
|
||||
*/
|
||||
export const generateCodeVerifier = (): string => {
|
||||
return client.randomPKCECodeVerifier();
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculate PKCE code challenge from verifier
|
||||
*/
|
||||
export const calculateCodeChallenge = async (codeVerifier: string): Promise<string> => {
|
||||
return client.calculatePKCECodeChallenge(codeVerifier);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get registered client info from cache
|
||||
*/
|
||||
export const getRegisteredClient = (serverName: string): RegisteredClientInfo | undefined => {
|
||||
return registeredClients.get(serverName);
|
||||
};
|
||||
|
||||
/**
|
||||
* Initialize OAuth for a server (performs registration if needed)
|
||||
* Now supports auto-detection via 401 responses with WWW-Authenticate header
|
||||
*
|
||||
* @param serverName - Name of the server
|
||||
* @param serverConfig - Server configuration
|
||||
* @param autoDetectedIssuer - Optional issuer URL from auto-detection
|
||||
* @param autoDetectedScopes - Optional scopes from auto-detection
|
||||
* @returns RegisteredClientInfo or null
|
||||
*/
|
||||
export const initializeOAuthForServer = async (
|
||||
serverName: string,
|
||||
serverConfig: ServerConfig,
|
||||
autoDetectedIssuer?: string,
|
||||
autoDetectedScopes?: string[],
|
||||
): Promise<RegisteredClientInfo | null> => {
|
||||
if (!serverConfig.oauth) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if dynamic registration should be attempted
|
||||
const shouldAttemptRegistration =
|
||||
autoDetectedIssuer || // Auto-detected from 401 response
|
||||
serverConfig.oauth.dynamicRegistration?.enabled === true || // Explicitly enabled
|
||||
(serverConfig.oauth.dynamicRegistration && !serverConfig.oauth.clientId); // Configured but no static client
|
||||
|
||||
if (shouldAttemptRegistration) {
|
||||
try {
|
||||
// Perform dynamic client registration
|
||||
const clientInfo = await registerClient(
|
||||
serverName,
|
||||
serverConfig,
|
||||
autoDetectedIssuer,
|
||||
autoDetectedScopes,
|
||||
);
|
||||
return clientInfo;
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize OAuth for server ${serverName}:`, error);
|
||||
// If auto-detection failed, don't throw - allow fallback to static config
|
||||
if (!autoDetectedIssuer) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Static client configuration - create Configuration from static values
|
||||
if (serverConfig.oauth.clientId) {
|
||||
// Try to fetch and store scopes if not already configured
|
||||
if (!serverConfig.oauth.scopes && serverConfig.url) {
|
||||
try {
|
||||
const fetchedScopes = await fetchScopesFromServer(serverConfig.url);
|
||||
if (fetchedScopes && fetchedScopes.length > 0) {
|
||||
await mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
oauth.scopes = fetchedScopes;
|
||||
});
|
||||
|
||||
if (!serverConfig.oauth) {
|
||||
serverConfig.oauth = {};
|
||||
}
|
||||
serverConfig.oauth.scopes = fetchedScopes;
|
||||
console.log(`Stored fetched scopes for ${serverName}: ${fetchedScopes.join(', ')}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(`Failed to fetch scopes for ${serverName}, will use defaults`);
|
||||
}
|
||||
}
|
||||
|
||||
// For static config, we need the authorization server URL
|
||||
let serverUrl: URL;
|
||||
|
||||
if (serverConfig.oauth.authorizationEndpoint) {
|
||||
const authUrl = new URL(serverConfig.oauth.authorizationEndpoint!);
|
||||
serverUrl = new URL(`${authUrl.protocol}//${authUrl.host}`);
|
||||
} else if (serverConfig.oauth.tokenEndpoint) {
|
||||
const tokenUrl = new URL(serverConfig.oauth.tokenEndpoint!);
|
||||
serverUrl = new URL(`${tokenUrl.protocol}//${tokenUrl.host}`);
|
||||
} else {
|
||||
console.warn(`Server ${serverName} has static OAuth config but missing endpoints`);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
// Discover the server configuration
|
||||
const clientAuth = serverConfig.oauth.clientSecret
|
||||
? client.ClientSecretPost(serverConfig.oauth.clientSecret)
|
||||
: client.None();
|
||||
|
||||
const config = await client.discovery(
|
||||
serverUrl,
|
||||
serverConfig.oauth.clientId!,
|
||||
undefined,
|
||||
clientAuth,
|
||||
);
|
||||
|
||||
const clientInfo: RegisteredClientInfo = {
|
||||
config,
|
||||
clientId: serverConfig.oauth.clientId!,
|
||||
clientSecret: serverConfig.oauth.clientSecret,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
registeredClients.set(serverName, clientInfo);
|
||||
return clientInfo;
|
||||
} catch (error) {
|
||||
console.error(`Failed to discover OAuth server for ${serverName}:`, error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
271
src/services/oauthService.ts
Normal file
271
src/services/oauthService.ts
Normal file
@@ -0,0 +1,271 @@
|
||||
import { ProxyOAuthServerProvider } from '@modelcontextprotocol/sdk/server/auth/providers/proxyProvider.js';
|
||||
import { mcpAuthRouter } from '@modelcontextprotocol/sdk/server/auth/router.js';
|
||||
import { RequestHandler } from 'express';
|
||||
import { loadSettings } from '../config/index.js';
|
||||
import { initializeOAuthForServer, refreshAccessToken } from './oauthClientRegistration.js';
|
||||
|
||||
// Re-export for external use
|
||||
export {
|
||||
getRegisteredClient,
|
||||
getAuthorizationUrl,
|
||||
exchangeCodeForToken,
|
||||
generateCodeVerifier,
|
||||
calculateCodeChallenge,
|
||||
autoDetectOAuthConfig,
|
||||
parseWWWAuthenticateHeader,
|
||||
fetchProtectedResourceMetadata,
|
||||
} from './oauthClientRegistration.js';
|
||||
|
||||
let oauthProvider: ProxyOAuthServerProvider | null = null;
|
||||
let oauthRouter: RequestHandler | null = null;
|
||||
|
||||
/**
|
||||
* Initialize OAuth provider from system configuration
|
||||
*/
|
||||
export const initOAuthProvider = (): void => {
|
||||
const settings = loadSettings();
|
||||
const oauthConfig = settings.systemConfig?.oauth;
|
||||
|
||||
if (!oauthConfig || !oauthConfig.enabled) {
|
||||
console.log('OAuth provider is disabled or not configured');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Create proxy OAuth provider
|
||||
oauthProvider = new ProxyOAuthServerProvider({
|
||||
endpoints: {
|
||||
authorizationUrl: oauthConfig.endpoints.authorizationUrl,
|
||||
tokenUrl: oauthConfig.endpoints.tokenUrl,
|
||||
revocationUrl: oauthConfig.endpoints.revocationUrl,
|
||||
},
|
||||
verifyAccessToken: async (token: string) => {
|
||||
// If a verification endpoint is configured, use it
|
||||
if (oauthConfig.verifyAccessToken?.endpoint) {
|
||||
const response = await fetch(oauthConfig.verifyAccessToken.endpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...oauthConfig.verifyAccessToken.headers,
|
||||
},
|
||||
body: JSON.stringify({ token }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Token verification failed: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
return {
|
||||
token,
|
||||
clientId: result.client_id || result.clientId || 'unknown',
|
||||
scopes: result.scopes || result.scope?.split(' ') || [],
|
||||
};
|
||||
}
|
||||
|
||||
// Default verification - just extract basic info from token
|
||||
// In production, you should decode/verify JWT or call an introspection endpoint
|
||||
return {
|
||||
token,
|
||||
clientId: 'default',
|
||||
scopes: oauthConfig.scopesSupported || [],
|
||||
};
|
||||
},
|
||||
getClient: async (clientId: string) => {
|
||||
// Find client in configuration
|
||||
const client = oauthConfig.clients?.find((c) => c.client_id === clientId);
|
||||
|
||||
if (!client) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
client_id: client.client_id,
|
||||
redirect_uris: client.redirect_uris,
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// Create OAuth router
|
||||
const issuerUrl = new URL(oauthConfig.issuerUrl);
|
||||
const baseUrl = oauthConfig.baseUrl ? new URL(oauthConfig.baseUrl) : issuerUrl;
|
||||
|
||||
oauthRouter = mcpAuthRouter({
|
||||
provider: oauthProvider,
|
||||
issuerUrl,
|
||||
baseUrl,
|
||||
serviceDocumentationUrl: oauthConfig.serviceDocumentationUrl
|
||||
? new URL(oauthConfig.serviceDocumentationUrl)
|
||||
: undefined,
|
||||
scopesSupported: oauthConfig.scopesSupported,
|
||||
});
|
||||
|
||||
console.log('OAuth provider initialized successfully');
|
||||
console.log(`OAuth issuer URL: ${issuerUrl.origin}`);
|
||||
// Only log endpoint URLs, not full config which might contain sensitive data
|
||||
console.log(
|
||||
'OAuth endpoints configured: authorization, token' +
|
||||
(oauthConfig.endpoints.revocationUrl ? ', revocation' : ''),
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize OAuth provider:', error);
|
||||
oauthProvider = null;
|
||||
oauthRouter = null;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the OAuth router if available
|
||||
*/
|
||||
export const getOAuthRouter = (): RequestHandler | null => {
|
||||
return oauthRouter;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the OAuth provider if available
|
||||
*/
|
||||
export const getOAuthProvider = (): ProxyOAuthServerProvider | null => {
|
||||
return oauthProvider;
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if OAuth is enabled
|
||||
*/
|
||||
export const isOAuthEnabled = (): boolean => {
|
||||
return oauthProvider !== null && oauthRouter !== null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get OAuth access token for a server if configured
|
||||
* Handles both static tokens and dynamic OAuth flows with automatic token refresh
|
||||
*/
|
||||
export const getServerOAuthToken = async (serverName: string): Promise<string | undefined> => {
|
||||
const settings = loadSettings();
|
||||
const serverConfig = settings.mcpServers[serverName];
|
||||
|
||||
if (!serverConfig?.oauth) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// If a pre-configured access token exists, use it
|
||||
if (serverConfig.oauth.accessToken) {
|
||||
// TODO: In a production system, check if token is expired and refresh if needed
|
||||
// For now, just return the configured token
|
||||
return serverConfig.oauth.accessToken;
|
||||
}
|
||||
|
||||
// If dynamic registration is enabled, initialize OAuth and get token
|
||||
if (serverConfig.oauth.dynamicRegistration?.enabled) {
|
||||
try {
|
||||
// Initialize OAuth for this server (registers client if needed)
|
||||
const clientInfo = await initializeOAuthForServer(serverName, serverConfig);
|
||||
|
||||
if (!clientInfo) {
|
||||
console.warn(`Failed to initialize OAuth for server: ${serverName}`);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// If we have a refresh token, try to get a new access token
|
||||
if (serverConfig.oauth.refreshToken) {
|
||||
try {
|
||||
const tokens = await refreshAccessToken(
|
||||
serverName,
|
||||
serverConfig,
|
||||
clientInfo,
|
||||
serverConfig.oauth.refreshToken,
|
||||
);
|
||||
return tokens.accessToken;
|
||||
} catch (error) {
|
||||
console.error(`Failed to refresh token for server ${serverName}:`, error);
|
||||
// Token refresh failed - user needs to re-authorize
|
||||
// In a production system, you would trigger a new authorization flow here
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// No access token and no refresh token available
|
||||
// User needs to go through the authorization flow
|
||||
// This would typically be triggered by an API endpoint that initiates the OAuth flow
|
||||
console.log(`Server ${serverName} requires user authorization via OAuth flow`);
|
||||
return undefined;
|
||||
} catch (error) {
|
||||
console.error(`Failed to get OAuth token for server ${serverName}:`, error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// Static client configuration - check for existing token
|
||||
if (serverConfig.oauth.clientId && serverConfig.oauth.accessToken) {
|
||||
return serverConfig.oauth.accessToken;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
|
||||
/**
|
||||
* Add OAuth authorization header to request headers if token is available
|
||||
*/
|
||||
export const addOAuthHeader = async (
|
||||
serverName: string,
|
||||
headers: Record<string, string>,
|
||||
): Promise<Record<string, string>> => {
|
||||
const token = await getServerOAuthToken(serverName);
|
||||
|
||||
if (token) {
|
||||
return {
|
||||
...headers,
|
||||
Authorization: `Bearer ${token}`,
|
||||
};
|
||||
}
|
||||
|
||||
return headers;
|
||||
};
|
||||
|
||||
/**
|
||||
* Initialize OAuth for all configured servers with explicit dynamic registration enabled
|
||||
* Servers without explicit configuration will be registered on-demand when receiving 401
|
||||
* Call this at application startup to pre-register known OAuth servers
|
||||
*/
|
||||
export const initializeAllOAuthClients = async (): Promise<void> => {
|
||||
const settings = loadSettings();
|
||||
|
||||
console.log('Initializing OAuth clients for explicitly configured servers...');
|
||||
|
||||
const serverNames = Object.keys(settings.mcpServers);
|
||||
const registrationPromises: Promise<void>[] = [];
|
||||
|
||||
for (const serverName of serverNames) {
|
||||
const serverConfig = settings.mcpServers[serverName];
|
||||
|
||||
// Only initialize servers with explicitly enabled dynamic registration
|
||||
// Others will be auto-detected and registered on first 401 response
|
||||
if (serverConfig.oauth?.dynamicRegistration?.enabled === true) {
|
||||
registrationPromises.push(
|
||||
initializeOAuthForServer(serverName, serverConfig)
|
||||
.then((clientInfo) => {
|
||||
if (clientInfo) {
|
||||
console.log(`✓ OAuth client pre-registered for server: ${serverName}`);
|
||||
} else {
|
||||
console.warn(`✗ Failed to pre-register OAuth client for server: ${serverName}`);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`✗ Error pre-registering OAuth client for server ${serverName}:`,
|
||||
error.message,
|
||||
);
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all registrations to complete
|
||||
if (registrationPromises.length > 0) {
|
||||
await Promise.all(registrationPromises);
|
||||
console.log(
|
||||
`OAuth client pre-registration completed for ${registrationPromises.length} server(s)`,
|
||||
);
|
||||
} else {
|
||||
console.log('No servers configured for pre-registration (will auto-detect on 401 responses)');
|
||||
}
|
||||
};
|
||||
158
src/services/oauthSettingsStore.ts
Normal file
158
src/services/oauthSettingsStore.ts
Normal file
@@ -0,0 +1,158 @@
|
||||
import { loadSettings, saveSettings } from '../config/index.js';
|
||||
import { McpSettings, ServerConfig } from '../types/index.js';
|
||||
|
||||
type OAuthConfig = NonNullable<ServerConfig['oauth']>;
|
||||
export type ServerConfigWithOAuth = ServerConfig & { oauth: OAuthConfig };
|
||||
|
||||
export interface OAuthSettingsContext {
|
||||
settings: McpSettings;
|
||||
serverConfig: ServerConfig;
|
||||
oauth: OAuthConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the latest server configuration from disk.
|
||||
*/
|
||||
export const loadServerConfig = (serverName: string): ServerConfig | undefined => {
|
||||
const settings = loadSettings();
|
||||
return settings.mcpServers?.[serverName];
|
||||
};
|
||||
|
||||
/**
|
||||
* Mutate OAuth configuration for a server and persist the updated settings.
|
||||
* The mutator receives the shared settings object to allow related updates when needed.
|
||||
*/
|
||||
export const mutateOAuthSettings = async (
|
||||
serverName: string,
|
||||
mutator: (context: OAuthSettingsContext) => void,
|
||||
): Promise<ServerConfigWithOAuth | undefined> => {
|
||||
const settings = loadSettings();
|
||||
const serverConfig = settings.mcpServers?.[serverName];
|
||||
|
||||
if (!serverConfig) {
|
||||
console.warn(`Server ${serverName} not found while updating OAuth settings`);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (!serverConfig.oauth) {
|
||||
serverConfig.oauth = {};
|
||||
}
|
||||
|
||||
const context: OAuthSettingsContext = {
|
||||
settings,
|
||||
serverConfig,
|
||||
oauth: serverConfig.oauth,
|
||||
};
|
||||
|
||||
mutator(context);
|
||||
|
||||
const saved = saveSettings(settings);
|
||||
if (!saved) {
|
||||
throw new Error(`Failed to persist OAuth settings for server ${serverName}`);
|
||||
}
|
||||
|
||||
return context.serverConfig as ServerConfigWithOAuth;
|
||||
};
|
||||
|
||||
export const persistClientCredentials = async (
|
||||
serverName: string,
|
||||
credentials: {
|
||||
clientId: string;
|
||||
clientSecret?: string;
|
||||
scopes?: string[];
|
||||
authorizationEndpoint?: string;
|
||||
tokenEndpoint?: string;
|
||||
},
|
||||
): Promise<ServerConfigWithOAuth | undefined> => {
|
||||
const updated = await mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
oauth.clientId = credentials.clientId;
|
||||
oauth.clientSecret = credentials.clientSecret;
|
||||
|
||||
if (credentials.scopes && credentials.scopes.length > 0) {
|
||||
oauth.scopes = credentials.scopes;
|
||||
}
|
||||
if (credentials.authorizationEndpoint) {
|
||||
oauth.authorizationEndpoint = credentials.authorizationEndpoint;
|
||||
}
|
||||
if (credentials.tokenEndpoint) {
|
||||
oauth.tokenEndpoint = credentials.tokenEndpoint;
|
||||
}
|
||||
});
|
||||
|
||||
console.log(`Persisted OAuth client credentials for server: ${serverName}`);
|
||||
if (credentials.scopes && credentials.scopes.length > 0) {
|
||||
console.log(`Stored OAuth scopes for ${serverName}: ${credentials.scopes.join(', ')}`);
|
||||
}
|
||||
|
||||
return updated;
|
||||
};
|
||||
|
||||
/**
|
||||
* Persist OAuth tokens and optionally replace the stored refresh token.
|
||||
*/
|
||||
export const persistTokens = async (
|
||||
serverName: string,
|
||||
tokens: {
|
||||
accessToken: string;
|
||||
refreshToken?: string | null;
|
||||
clearPendingAuthorization?: boolean;
|
||||
},
|
||||
): Promise<ServerConfigWithOAuth | undefined> => {
|
||||
return mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
oauth.accessToken = tokens.accessToken;
|
||||
|
||||
if (tokens.refreshToken !== undefined) {
|
||||
if (tokens.refreshToken) {
|
||||
oauth.refreshToken = tokens.refreshToken;
|
||||
} else {
|
||||
delete oauth.refreshToken;
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens.clearPendingAuthorization && oauth.pendingAuthorization) {
|
||||
delete oauth.pendingAuthorization;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Update or create a pending authorization record.
|
||||
*/
|
||||
export const updatePendingAuthorization = async (
|
||||
serverName: string,
|
||||
pending: Partial<NonNullable<OAuthConfig['pendingAuthorization']>>,
|
||||
): Promise<ServerConfigWithOAuth | undefined> => {
|
||||
return mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
oauth.pendingAuthorization = {
|
||||
...(oauth.pendingAuthorization || {}),
|
||||
...pending,
|
||||
createdAt: pending.createdAt ?? Date.now(),
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Clear cached OAuth data using shared helpers.
|
||||
*/
|
||||
export const clearOAuthData = async (
|
||||
serverName: string,
|
||||
scope: 'all' | 'client' | 'tokens' | 'verifier',
|
||||
): Promise<ServerConfigWithOAuth | undefined> => {
|
||||
return mutateOAuthSettings(serverName, ({ oauth }) => {
|
||||
if (scope === 'tokens' || scope === 'all') {
|
||||
delete oauth.accessToken;
|
||||
delete oauth.refreshToken;
|
||||
}
|
||||
|
||||
if (scope === 'client' || scope === 'all') {
|
||||
delete oauth.clientId;
|
||||
delete oauth.clientSecret;
|
||||
}
|
||||
|
||||
if (scope === 'verifier' || scope === 'all') {
|
||||
if (oauth.pendingAuthorization) {
|
||||
delete oauth.pendingAuthorization;
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user