mirror of
https://github.com/samanhappy/mcphub.git
synced 2025-12-24 02:39:19 -05:00
389 lines
14 KiB
TypeScript
389 lines
14 KiB
TypeScript
/**
 * OAuth Callback Controller
 *
 * Handles OAuth 2.0 authorization callbacks for upstream MCP servers.
 *
 * This controller implements a simplified callback flow that relies on the MCP SDK
 * to handle the complete OAuth token exchange:
 *
 * 1. Extract authorization code from callback URL
 * 2. Find the corresponding server using the state parameter
 * 3. Store the authorization code temporarily
 * 4. Reconnect the server - SDK's auth() function will:
 *    - Automatically discover OAuth endpoints
 *    - Exchange the code for tokens using PKCE
 *    - Save tokens via our OAuthClientProvider.saveTokens()
 */
|
|
import { Request, Response } from 'express';
|
|
import {
|
|
getServerByName,
|
|
getServerByOAuthState,
|
|
createTransportFromConfig,
|
|
} from '../services/mcpService.js';
|
|
import { getNameSeparator, loadSettings } from '../config/index.js';
|
|
import type { ServerInfo } from '../types/index.js';
|
|
|
|
/**
 * Build the standalone HTML page shown in the browser after an OAuth callback.
 *
 * Every caller-supplied string (title, message, detail labels and values) is
 * HTML-escaped before being interpolated. The callback handler in this file passes
 * values taken directly from the request query string (`error`,
 * `error_description`), so unescaped interpolation would allow reflected XSS.
 *
 * @param type - Selects the colour scheme; `'success'` also prefixes the heading with a check mark.
 * @param title - Text used for both the `<title>` element and the `<h1>` heading.
 * @param message - Paragraph text rendered below the details.
 * @param details - Optional label/value pairs rendered as `.detail` rows above the message.
 * @param autoClose - When true, adds a script that closes the window after 3 seconds
 *   plus a notice, and labels the button "Close Now" instead of "Close Window".
 * @returns The complete HTML document as a string.
 */
const generateHtmlResponse = (
  type: 'error' | 'success',
  title: string,
  message: string,
  details?: { label: string; value: string }[],
  autoClose: boolean = false,
): string => {
  // `&` must be replaced first so the entities produced below are not re-escaped.
  // String() guards against non-string values arriving through untyped callers.
  const escapeHtml = (text: string): string =>
    String(text)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');

  const isError = type === 'error';
  const backgroundColor = isError ? '#fee' : '#efe';
  const borderColor = isError ? '#fcc' : '#cfc';
  const titleColor = isError ? '#c33' : '#3c3';
  const buttonColor = isError ? '#c33' : '#3c3';

  const safeTitle = escapeHtml(title);
  const safeMessage = escapeHtml(message);
  const detailsHtml = details
    ? details
        .map(
          (d) =>
            `<div class="detail"><strong>${escapeHtml(d.label)}:</strong> ${escapeHtml(d.value)}</div>`,
        )
        .join('')
    : '';

  return `
    <!DOCTYPE html>
    <html>
      <head>
        <title>${safeTitle}</title>
        <style>
          body { font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }
          .container { background-color: ${backgroundColor}; border: 1px solid ${borderColor}; padding: 20px; border-radius: 8px; }
          h1 { color: ${titleColor}; margin-top: 0; }
          .detail { margin-top: 10px; padding: 10px; background: #f9f9f9; border-radius: 4px; ${isError ? 'font-family: monospace; font-size: 12px; white-space: pre-wrap;' : ''} }
          .close-btn { margin-top: 20px; padding: 10px 20px; background: ${buttonColor}; color: white; border: none; border-radius: 4px; cursor: pointer; }
        </style>
        ${autoClose ? '<script>setTimeout(() => { window.close(); }, 3000);</script>' : ''}
      </head>
      <body>
        <div class="container">
          <h1>${type === 'success' ? '✓ ' : ''}${safeTitle}</h1>
          ${detailsHtml}
          <p>${safeMessage}</p>
          ${autoClose ? '<p>This window will close automatically in 3 seconds...</p>' : ''}
          <button class="close-btn" onclick="window.close()">${autoClose ? 'Close Now' : 'Close Window'}</button>
        </div>
      </body>
    </html>
  `;
};
|
|
/**
 * Reduce a raw query-string value to a single string.
 *
 * A string is returned as-is. For an array, the first element is returned if it
 * is a string. Anything else (empty array, non-string first element, objects,
 * `undefined`, ...) yields `undefined`.
 *
 * @param value - Raw value read from `req.query`.
 * @returns The string value, or `undefined` when none can be extracted.
 */
const normalizeQueryParam = (value: unknown): string | undefined => {
  // For arrays only the first entry is considered; an empty array gives `undefined` here.
  const candidate: unknown = Array.isArray(value) ? value[0] : value;
  return typeof candidate === 'string' ? candidate : undefined;
};
|
|
/**
 * Try to recover a server name from an OAuth `state` value.
 *
 * Two encodings are attempted, in order:
 * 1. base64url-encoded JSON with a string `server` property;
 * 2. a `<name>:<rest>` string, where the text before the first `:` is the name.
 *
 * @param stateValue - The `state` query parameter received on the callback.
 * @returns The server name, or `undefined` if neither form applies.
 */
const extractServerNameFromState = (stateValue: string): string | undefined => {
  try {
    // Convert the base64url alphabet to standard base64 and restore the stripped padding.
    const standardAlphabet = stateValue.replace(/-/g, '+').replace(/_/g, '/');
    const missingPadding = (4 - (standardAlphabet.length % 4)) % 4;
    const json = Buffer.from(standardAlphabet + '='.repeat(missingPadding), 'base64').toString(
      'utf8',
    );

    // JSON.parse may return null or a primitive; optional chaining makes those read as undefined.
    const payload: unknown = JSON.parse(json);
    const server = (payload as { server?: unknown } | null | undefined)?.server;
    if (typeof server === 'string') {
      return server;
    }
  } catch {
    // Not base64url JSON; fall back to delimiter-based parsing below.
  }

  // Require at least one character before the separator.
  const separatorIndex = stateValue.indexOf(':');
  return separatorIndex > 0 ? stateValue.slice(0, separatorIndex) : undefined;
};
|
|
/** Translation lookup as used by the OAuth callback pages. */
type OAuthCallbackTranslate = (key: string) => string;

/**
 * Find the server an OAuth callback belongs to: first by exact `state` match, then
 * by a server name decoded from the `state` value itself.
 */
const locateOAuthCallbackServer = (
  stateParam: string,
): { serverInfo: ServerInfo | undefined; decodedServerName: string | undefined } => {
  const matchedByState = getServerByOAuthState(stateParam);
  if (matchedByState) {
    return { serverInfo: matchedByState, decodedServerName: undefined };
  }

  const decodedServerName = extractServerNameFromState(stateParam);
  if (!decodedServerName) {
    return { serverInfo: undefined, decodedServerName };
  }

  console.log(`State lookup failed; decoding server name from state: ${decodedServerName}`);
  return { serverInfo: getServerByName(decodedServerName), decodedServerName };
};

/**
 * Connect the server's client over its current transport (unless it already reports
 * server capabilities) and refresh the cached tool list. Connection failures are
 * logged and swallowed so the OAuth flow is still reported as complete.
 */
const connectOAuthCallbackClient = async (serverInfo: ServerInfo): Promise<void> => {
  if (serverInfo.client && serverInfo.client.getServerCapabilities()) {
    console.log(`Client already connected for server: ${serverInfo.name}`);
    return;
  }

  if (!serverInfo.client || !serverInfo.transport) {
    console.log(`Cannot connect client for ${serverInfo.name}: client or transport missing`);
    return;
  }

  console.log(`Connecting client with refreshed transport for: ${serverInfo.name}`);
  try {
    await serverInfo.client.connect(serverInfo.transport, serverInfo.options);
    console.log(`Client connected successfully for: ${serverInfo.name}`);

    // List tools after successful connection
    const capabilities = serverInfo.client.getServerCapabilities();
    console.log(`Server capabilities for ${serverInfo.name}:`, JSON.stringify(capabilities));

    if (!capabilities?.tools) {
      console.log(`Server ${serverInfo.name} does not support tools capability`);
      return;
    }

    console.log(`Listing tools for server: ${serverInfo.name}`);
    const toolsResult = await serverInfo.client.listTools({}, serverInfo.options);
    const separator = getNameSeparator();
    // Tool names are prefixed with the server name so they stay unique across servers.
    serverInfo.tools = toolsResult.tools.map((tool) => ({
      name: `${serverInfo.name}${separator}${tool.name}`,
      description: tool.description || '',
      inputSchema: tool.inputSchema || {},
    }));
    console.log(`Listed ${serverInfo.tools.length} tools for server: ${serverInfo.name}`);
  } catch (connectError) {
    console.error(`Error connecting client for ${serverInfo.name}:`, connectError);
    if (connectError instanceof Error) {
      console.error(
        `Connect error details for ${serverInfo.name}: ${connectError.message}`,
        connectError.stack,
      );
    }
    // Even if connection fails, mark OAuth as complete
    // The user can try reconnecting from the dashboard
  }
};

/**
 * Exchange the authorization code via `transport.finishAuth()`, rebuild the transport
 * from the freshly loaded settings, clear the pending OAuth fields and connect the client.
 * The caller must have verified that `serverInfo.transport` exposes `finishAuth`.
 *
 * @throws Whatever `finishAuth()` or `createTransportFromConfig()` throws, or an Error
 *   when no configuration can be found for the server.
 */
const completeOAuthCallbackForServer = async (
  serverInfo: ServerInfo,
  authorizationCode: string,
): Promise<void> => {
  console.log(`Calling transport.finishAuth() for server: ${serverInfo.name}`);
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- finishAuth is only present on some transports
  await (serverInfo.transport as any).finishAuth(authorizationCode);
  console.log(`Successfully exchanged authorization code for tokens: ${serverInfo.name}`);

  // Refresh server configuration from disk to ensure we pick up newly saved tokens
  const settings = loadSettings();
  const effectiveConfig = settings.mcpServers?.[serverInfo.name] || serverInfo.config;
  if (!effectiveConfig) {
    throw new Error(`Missing server configuration for ${serverInfo.name} after OAuth callback`);
  }

  // Keep latest configuration cached on serverInfo
  serverInfo.config = effectiveConfig;

  // Ensure we have request options for the reconnect attempt
  if (!serverInfo.options) {
    const requestConfig = effectiveConfig.options || {};
    serverInfo.options = {
      timeout: requestConfig.timeout || 60000,
      resetTimeoutOnProgress: requestConfig.resetTimeoutOnProgress || false,
      maxTotalTimeout: requestConfig.maxTotalTimeout,
    };
  }

  // Replace the existing transport instance to avoid reusing a closed/aborted transport.
  // A failure to close the old one is not fatal.
  try {
    if (serverInfo.transport && 'close' in serverInfo.transport) {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any -- close() is not on every transport type
      await (serverInfo.transport as any).close();
    }
  } catch (closeError) {
    console.warn(`Failed to close existing transport for ${serverInfo.name}:`, closeError);
  }

  console.log(`Rebuilding transport with refreshed credentials for server: ${serverInfo.name}`);
  serverInfo.transport = await createTransportFromConfig(serverInfo.name, effectiveConfig);

  // Update server status to indicate OAuth is complete.
  // NOTE(review): status becomes 'connected' before the client connects and is not
  // reverted if the connect attempt below fails — confirm this is intended.
  serverInfo.status = 'connected';
  if (serverInfo.oauth) {
    serverInfo.oauth.authorizationUrl = undefined;
    serverInfo.oauth.state = undefined;
    serverInfo.oauth.codeVerifier = undefined;
  }

  await connectOAuthCallbackClient(serverInfo);
};

/**
 * Handle OAuth callback after user authorization
 *
 * This endpoint receives the authorization code from the OAuth provider
 * and initiates the server reconnection process. It always answers with an
 * HTML page built by `generateHtmlResponse`:
 * - 400 when the provider reported an error, when `state` or `code` is missing,
 *   or when no server matches the state;
 * - 500 when the server's transport lacks `finishAuth()`, when completing the
 *   flow throws, or on any unexpected error;
 * - 200 with an auto-closing success page otherwise.
 *
 * Expected query parameters:
 * - code: Authorization code from OAuth provider
 * - state: Encoded server identifier used for OAuth session validation
 * - error: Optional error code if authorization failed
 * - error_description: Optional error description
 *
 * @param req - Express request; `req.t` (if set) is used to translate page text.
 * @param res - Express response used to send the HTML page.
 */
export const handleOAuthCallback = async (req: Request, res: Response) => {
  // Get translation function from request (set by i18n middleware); fall back to echoing the key.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- `t` is attached by middleware, not declared on Request here
  const t: OAuthCallbackTranslate = (req as any).t || ((key: string) => key);

  try {
    const { code, state, error, error_description } = req.query;
    const codeParam = normalizeQueryParam(code);
    const stateParam = normalizeQueryParam(state);

    // Check for authorization errors
    if (error) {
      console.error(`OAuth authorization failed: ${error} - ${error_description || ''}`);
      const errorDetails = [
        { label: t('oauthCallback.authorizationFailedError'), value: String(error) },
      ];
      if (error_description) {
        errorDetails.push({
          label: t('oauthCallback.authorizationFailedDetails'),
          value: String(error_description),
        });
      }
      return res
        .status(400)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.authorizationFailed'),
            '',
            errorDetails,
          ),
        );
    }

    // Validate required parameters
    if (!stateParam) {
      console.error('OAuth callback missing state parameter');
      return res
        .status(400)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.invalidRequest'),
            t('oauthCallback.missingStateParameter'),
          ),
        );
    }

    if (!codeParam) {
      console.error('OAuth callback missing authorization code');
      return res
        .status(400)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.invalidRequest'),
            t('oauthCallback.missingCodeParameter'),
          ),
        );
    }

    console.log(`OAuth callback received - code: present, state: ${stateParam}`);

    // Find server by state parameter
    const { serverInfo, decodedServerName } = locateOAuthCallbackServer(stateParam);
    if (!serverInfo) {
      console.error(
        `No server found for OAuth callback. State: ${stateParam}${
          decodedServerName ? `, decoded server: ${decodedServerName}` : ''
        }`,
      );
      return res
        .status(400)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.serverNotFound'),
            `${t('oauthCallback.serverNotFoundMessage')}\n${t('oauthCallback.sessionExpiredMessage')}`,
          ),
        );
    }

    // Optional: Validate state parameter for additional security.
    // NOTE(review): a mismatch is only logged; the request still proceeds when the
    // server was found via the name decoded from `state`. Confirm this is acceptable.
    if (serverInfo.oauth?.state && serverInfo.oauth.state !== stateParam) {
      console.warn(
        `State mismatch for server ${serverInfo.name}. Expected: ${serverInfo.oauth.state}, Got: ${stateParam}`,
      );
    }

    console.log(`Processing OAuth callback for server: ${serverInfo.name}`);

    // The code exchange is delegated to transport.finishAuth(); without it we cannot proceed.
    if (!(serverInfo.transport && 'finishAuth' in serverInfo.transport)) {
      console.error(`Transport for server ${serverInfo.name} does not support finishAuth()`);
      return res
        .status(500)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.configurationError'),
            t('oauthCallback.configurationErrorMessage'),
          ),
        );
    }

    try {
      await completeOAuthCallbackForServer(serverInfo, codeParam);

      console.log(`Successfully completed OAuth flow for server: ${serverInfo.name}`);

      // Return success page
      return res.status(200).send(
        generateHtmlResponse(
          'success',
          t('oauthCallback.authorizationSuccessful'),
          `${t('oauthCallback.successMessage')}\n${t('oauthCallback.autoCloseMessage')}`,
          [
            { label: t('oauthCallback.server'), value: serverInfo.name },
            { label: t('oauthCallback.status'), value: t('oauthCallback.connected') },
          ],
          true, // auto-close
        ),
      );
    } catch (flowError: unknown) {
      // The catch variable is `unknown`; narrow before reading properties from it.
      const errorMessage = flowError instanceof Error ? flowError.message : String(flowError);
      const constructorName = (
        flowError as { constructor?: { name?: string } } | null | undefined
      )?.constructor?.name;

      console.error(`Failed to complete OAuth flow for server ${serverInfo.name}:`, flowError);
      console.error(`Error type: ${typeof flowError}, Error name: ${constructorName}`);
      console.error(`Error message: ${errorMessage}`);
      console.error(
        `Error stack:`,
        flowError instanceof Error ? flowError.stack : 'No stack trace',
      );

      return res
        .status(500)
        .send(
          generateHtmlResponse(
            'error',
            t('oauthCallback.connectionError'),
            `${t('oauthCallback.connectionErrorMessage')}\n${t('oauthCallback.reconnectMessage')}`,
            [{ label: '', value: errorMessage }],
          ),
        );
    }
  } catch (unexpectedError: unknown) {
    console.error('Unexpected error handling OAuth callback:', unexpectedError);

    return res
      .status(500)
      .send(
        generateHtmlResponse(
          'error',
          t('oauthCallback.internalError'),
          t('oauthCallback.internalErrorMessage'),
        ),
      );
  }
};