From a55405b974d55c8e9012749304258c2668d758db Mon Sep 17 00:00:00 2001 From: samanhappy Date: Sun, 26 Oct 2025 21:41:03 +0800 Subject: [PATCH] feat: add cluster routing support --- src/services/clusterService.ts | 457 +++++++++++++++++++++++++++++++++ src/services/mcpService.ts | 3 + src/services/sseService.ts | 18 ++ src/types/index.ts | 15 ++ tests/clusterService.test.ts | 67 +++++ 5 files changed, 560 insertions(+) create mode 100644 src/services/clusterService.ts create mode 100644 tests/clusterService.test.ts diff --git a/src/services/clusterService.ts b/src/services/clusterService.ts new file mode 100644 index 0000000..bb097a8 --- /dev/null +++ b/src/services/clusterService.ts @@ -0,0 +1,457 @@ +import { Request, Response } from 'express'; +import { URL } from 'url'; +import config, { loadSettings } from '../config/index.js'; +import { ClusterConfig, ClusterNodeConfig } from '../types/index.js'; + +interface ProxyContext { + node: ClusterNodeConfig; + targetUrl: URL; +} + +const sessionBindings = new Map(); +const groupCounters = new Map(); + +const DEFAULT_GROUP_KEY = '__default__'; + +const isIterableHeaderValue = (value: string | string[] | undefined): value is string[] => { + return Array.isArray(value); +}; + +const createHeadersFromRequest = (req: Request, node: ClusterNodeConfig): Headers => { + const headers = new Headers(); + for (const [key, rawValue] of Object.entries(req.headers)) { + if (rawValue === undefined) { + continue; + } + if (key.toLowerCase() === 'host') { + continue; + } + if (isIterableHeaderValue(rawValue)) { + for (const value of rawValue) { + headers.append(key, value); + } + } else { + headers.set(key, String(rawValue)); + } + } + + if (node.forwardHeaders) { + for (const [key, value] of Object.entries(node.forwardHeaders)) { + if (value !== undefined) { + headers.set(key, value); + } + } + } + + return headers; +}; + +const getClusterConfig = (): ClusterConfig | undefined => { + const settings = loadSettings(); + return settings.systemConfig?.cluster; +}; + +const getClusterNodes = (): ClusterNodeConfig[] => { + const config = getClusterConfig(); + if (!config?.enabled) { + return []; + } + return config.nodes ?? []; +}; + +const isClusterEnabled = (): boolean => { + return getClusterNodes().length > 0; +}; + +const sanitizePathSegment = (segment: string): string => { + return segment.replace(/^\/+/, '').replace(/\/+$/, ''); +}; + +const joinUrlPaths = (...segments: (string | undefined)[]): string => { + const sanitizedSegments = segments + .filter((segment): segment is string => segment !== undefined && segment !== null && segment !== '') + .map((segment) => sanitizePathSegment(segment)); + + if (!sanitizedSegments.length) { + return '/'; + } + + const joined = sanitizedSegments.filter((segment) => segment.length > 0).join('/'); + return joined ? `/${joined}` : '/'; +}; + +const normalizeBasePath = (path?: string): string => { + if (!path) { + return ''; + } + const normalized = path.startsWith('/') ? path : `/${path}`; + if (normalized === '/') { + return ''; + } + if (normalized !== '/' && normalized.endsWith('/')) { + return normalized.slice(0, -1); + } + return normalized; +}; + +const buildTargetUrl = (node: ClusterNodeConfig, originalUrl: string): URL => { + const placeholderBase = 'http://cluster.local'; + const requestUrl = new URL(originalUrl, placeholderBase); + const requestPath = requestUrl.pathname; + const hubBasePath = normalizeBasePath(config.basePath); + const relativePath = requestPath.startsWith(hubBasePath) + ? requestPath.slice(hubBasePath.length) || '/' + : requestPath; + + const nodePrefix = normalizeBasePath(node.pathPrefix ?? hubBasePath); + const targetUrl = new URL(node.url); + targetUrl.pathname = joinUrlPaths(targetUrl.pathname, nodePrefix, relativePath); + targetUrl.search = requestUrl.search; + targetUrl.hash = requestUrl.hash; + return targetUrl; +}; + +const matchesNodeGroup = (nodeGroup: string, targetGroup: string): boolean => { + if (!targetGroup) { + return nodeGroup === '' || nodeGroup === '*' || nodeGroup === 'global' || nodeGroup === 'default'; + } + + if (nodeGroup === '*') { + return true; + } + + return nodeGroup === targetGroup; +}; + +const selectNodeForGroup = (group?: string): ClusterNodeConfig | undefined => { + const nodes = getClusterNodes(); + if (!nodes.length) { + return undefined; + } + + const key = group ?? DEFAULT_GROUP_KEY; + const normalizedGroup = group ?? ''; + const candidates = nodes.filter((node) => { + if (!node.groups || node.groups.length === 0) { + return true; + } + + return node.groups.some((nodeGroup) => matchesNodeGroup(nodeGroup, normalizedGroup)); + }); + + if (!candidates.length) { + return undefined; + } + + const weightedCandidates: ClusterNodeConfig[] = []; + for (const candidate of candidates) { + const weight = Math.max(1, candidate.weight ?? 1); + for (let i = 0; i < weight; i += 1) { + weightedCandidates.push(candidate); + } + } + + const index = groupCounters.get(key) ?? 0; + const selected = weightedCandidates[index % weightedCandidates.length]; + groupCounters.set(key, index + 1); + return selected; +}; + +const bindSessionToNode = (sessionId: string, nodeId: string): void => { + sessionBindings.set(sessionId, nodeId); +}; + +const releaseSession = (sessionId: string): void => { + sessionBindings.delete(sessionId); +}; + +const getNodeForSession = (sessionId: string): ClusterNodeConfig | undefined => { + const nodeId = sessionBindings.get(sessionId); + if (!nodeId) { + return undefined; + } + return getClusterNodes().find((node) => node.id === nodeId); +}; + +const resolveProxyContext = (req: Request, group?: string, sessionId?: string): ProxyContext | undefined => { + if (!isClusterEnabled()) { + return undefined; + } + + if (sessionId) { + const node = getNodeForSession(sessionId); + if (node) { + return { node, targetUrl: buildTargetUrl(node, req.originalUrl) }; + } + } + + const node = selectNodeForGroup(group); + if (!node) { + return undefined; + } + + return { + node, + targetUrl: buildTargetUrl(node, req.originalUrl), + }; +}; + +const pipeReadableStreamToResponse = async ( + response: globalThis.Response, + res: Response, + onData?: (chunk: string) => void, +): Promise => { + if (!response.body) { + const text = await response.text(); + res.send(text); + return; + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + + try { + let finished = false; + while (!finished) { + const { value, done } = await reader.read(); + finished = Boolean(done); + if (value) { + const chunkString = decoder.decode(value, { stream: true }); + if (onData) { + onData(chunkString); + } + res.write(Buffer.from(value)); + } + } + } catch (error) { + if ((error as Error).name !== 'AbortError') { + console.error('Cluster proxy stream error:', error); + } + } finally { + const finalChunk = decoder.decode(); + if (finalChunk && onData) { + onData(finalChunk); + } + res.end(); + } +}; + +const handleSseStream = async ( + node: ClusterNodeConfig, + req: Request, + res: Response, + context: ProxyContext, +): Promise => { + const controller = new AbortController(); + const sessionIds = new Set(); + req.on('close', () => { + controller.abort(); + for (const sessionId of sessionIds) { + releaseSession(sessionId); + } + }); + + let response: globalThis.Response; + try { + response = await fetch(context.targetUrl, { + method: 'GET', + headers: createHeadersFromRequest(req, node), + signal: controller.signal, + }); + } catch (error) { + console.error('Failed to proxy SSE request to cluster node:', error); + if (!res.headersSent) { + res.status(502).send('Failed to reach cluster node'); + } + for (const sessionId of sessionIds) { + releaseSession(sessionId); + } + return; + } + + res.status(response.status); + response.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (typeof res.flushHeaders === 'function') { + res.flushHeaders(); + } + + const isSse = response.headers.get('content-type')?.includes('text/event-stream'); + let buffer = ''; + await pipeReadableStreamToResponse( + response, + res, + isSse + ? (chunk) => { + buffer += chunk; + let boundaryIndex = buffer.indexOf('\n\n'); + while (boundaryIndex !== -1) { + const rawEvent = buffer.slice(0, boundaryIndex); + buffer = buffer.slice(boundaryIndex + 2); + const normalizedEvent = rawEvent.replace(/\r\n/g, '\n'); + const lines = normalizedEvent.split('\n'); + let eventName = ''; + let data = ''; + for (const line of lines) { + if (line.startsWith('event:')) { + eventName = line.slice(6).trim(); + } + if (line.startsWith('data:')) { + data += `${line.slice(5).trim()}`; + } + } + if (eventName === 'endpoint' && data) { + try { + const sessionUrl = new URL(data, 'http://localhost'); + const sessionId = sessionUrl.searchParams.get('sessionId'); + if (sessionId) { + bindSessionToNode(sessionId, node.id); + sessionIds.add(sessionId); + } + } catch (error) { + console.warn('Failed to parse session endpoint from cluster response:', error); + } + } + boundaryIndex = buffer.indexOf('\n\n'); + } + } + : undefined, + ); + + for (const sessionId of sessionIds) { + releaseSession(sessionId); + } +}; + +const forwardRequest = async ( + req: Request, + res: Response, + context: ProxyContext, + options?: { releaseSession?: string }, +): Promise => { + const { node, targetUrl } = context; + const method = req.method.toUpperCase(); + const init: RequestInit = { + method, + headers: createHeadersFromRequest(req, node), + }; + + if (method === 'POST' || method === 'PUT' || method === 'PATCH') { + if (req.body !== undefined) { + init.body = typeof req.body === 'string' ? req.body : JSON.stringify(req.body); + } + } + + const controller = new AbortController(); + init.signal = controller.signal; + req.on('close', () => { + controller.abort(); + }); + + let response: globalThis.Response; + try { + response = await fetch(targetUrl, init); + } catch (error) { + if ((error as Error).name !== 'AbortError') { + console.error('Failed to proxy request to cluster node:', error); + } + if (!res.headersSent) { + res.status(502).send('Failed to reach cluster node'); + } + if (options?.releaseSession) { + releaseSession(options.releaseSession); + } + return; + } + + const newSessionId = response.headers.get('mcp-session-id'); + if (newSessionId) { + bindSessionToNode(newSessionId, node.id); + } + + res.status(response.status); + response.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (response.headers.get('content-type')?.includes('text/event-stream')) { + await pipeReadableStreamToResponse(response, res); + } else { + const buffer = await response.arrayBuffer(); + if (buffer.byteLength === 0) { + res.end(); + } else { + res.send(Buffer.from(buffer)); + } + } + + if (options?.releaseSession) { + releaseSession(options.releaseSession); + } +}; + +export const tryProxySseConnection = async ( + req: Request, + res: Response, + group?: string, +): Promise => { + const context = resolveProxyContext(req, group); + if (!context) { + return false; + } + + await handleSseStream(context.node, req, res, context); + return true; +}; + +export const tryProxySseMessage = async (req: Request, res: Response): Promise => { + const sessionId = typeof req.query.sessionId === 'string' ? req.query.sessionId : undefined; + if (!sessionId) { + return false; + } + + const context = resolveProxyContext(req, undefined, sessionId); + if (!context) { + return false; + } + + await forwardRequest(req, res, context); + return true; +}; + +export const tryProxyMcpRequest = async ( + req: Request, + res: Response, + group?: string, +): Promise => { + const sessionIdHeader = req.headers['mcp-session-id']; + const sessionId = Array.isArray(sessionIdHeader) ? sessionIdHeader[0] : sessionIdHeader; + const context = resolveProxyContext(req, group, sessionId); + if (!context) { + return false; + } + + const releaseTarget = req.method.toUpperCase() === 'DELETE' ? sessionId : undefined; + await forwardRequest(req, res, context, { releaseSession: releaseTarget }); + return true; +}; + +export const clearClusterSessionBindings = (): void => { + sessionBindings.clear(); + groupCounters.clear(); +}; + +export const __clusterInternals = { + joinUrlPaths, + normalizeBasePath, + matchesNodeGroup, + buildTargetUrl, +}; diff --git a/src/services/mcpService.ts b/src/services/mcpService.ts index 85a6729..d5c156a 100644 --- a/src/services/mcpService.ts +++ b/src/services/mcpService.ts @@ -26,6 +26,7 @@ import { getDataService } from './services.js'; import { getServerDao, ServerConfigWithName } from '../dao/index.js'; import { initializeAllOAuthClients } from './oauthService.js'; import { createOAuthProvider } from './mcpOAuthProvider.js'; +import { clearClusterSessionBindings } from './clusterService.js'; const servers: { [sessionId: string]: Server } = {}; @@ -161,6 +162,8 @@ export const cleanupAllServers = (): void => { Object.keys(servers).forEach((sessionId) => { delete servers[sessionId]; }); + + clearClusterSessionBindings(); }; // Helper function to create transport based on server configuration diff --git a/src/services/sseService.ts b/src/services/sseService.ts index 05faa11..0687686 100644 --- a/src/services/sseService.ts +++ b/src/services/sseService.ts @@ -9,6 +9,7 @@ import { loadSettings } from '../config/index.js'; import config from '../config/index.js'; import { UserContextService } from './userContextService.js'; import { RequestContextService } from './requestContextService.js'; +import { tryProxyMcpRequest, tryProxySseConnection, tryProxySseMessage } from './clusterService.js'; const transports: { [sessionId: string]: { transport: Transport; group: string } } = {}; @@ -81,6 +82,10 @@ export const handleSseConnection = async (req: Request, res: Response): Promise< console.log(`Creating SSE transport with messages path: ${messagesPath}`); + if (await tryProxySseConnection(req, res, group)) { + return; + } + const transport = new SSEServerTransport(messagesPath, res); transports[transport.sessionId] = { transport, group: group }; @@ -117,6 +122,10 @@ export const handleSseMessage = async (req: Request, res: Response): Promise { return; } + const group = req.params.group; + if (await tryProxyMcpRequest(req, res, group)) { + return; + } + const sessionId = req.headers['mcp-session-id'] as string | undefined; if (!sessionId || !transports[sessionId]) { res.status(400).send('Invalid or missing session ID'); diff --git a/src/types/index.ts b/src/types/index.ts index 44baf23..9f64eb1 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -62,6 +62,20 @@ export interface MarketServerTool { inputSchema: Record; } +export interface ClusterNodeConfig { + id: string; // Unique identifier for the node + url: string; // Base URL for the node (e.g. http://node-a:3000) + groups?: string[]; // Optional list of group identifiers served by this node; include empty string for global routes + weight?: number; // Optional weight for load balancing + forwardHeaders?: Record; // Additional headers forwarded to the node on every request + pathPrefix?: string; // Optional prefix prepended before forwarding paths (defaults to hub base path) +} + +export interface ClusterConfig { + enabled?: boolean; // Flag to enable/disable cluster routing + nodes?: ClusterNodeConfig[]; // Cluster node definitions +} + export interface MarketServer { name: string; display_name: string; @@ -171,6 +185,7 @@ export interface SystemConfig { }; nameSeparator?: string; // Separator used between server name and tool/prompt name (default: '-') oauth?: OAuthProviderConfig; // OAuth provider configuration for upstream MCP servers + cluster?: ClusterConfig; // Cluster configuration for multi-node deployments } export interface UserConfig { diff --git a/tests/clusterService.test.ts b/tests/clusterService.test.ts new file mode 100644 index 0000000..1b5850d --- /dev/null +++ b/tests/clusterService.test.ts @@ -0,0 +1,67 @@ +import { ClusterNodeConfig } from '../src/types/index.js'; +import config from '../src/config/index.js'; +import { __clusterInternals } from '../src/services/clusterService.js'; + +const { buildTargetUrl, normalizeBasePath, matchesNodeGroup, joinUrlPaths } = __clusterInternals; + +describe('clusterService internals', () => { + const originalBasePath = config.basePath; + + afterEach(() => { + config.basePath = originalBasePath; + }); + + test('normalizeBasePath trims trailing slashes and enforces leading slash', () => { + expect(normalizeBasePath('')).toBe(''); + expect(normalizeBasePath('/')).toBe(''); + expect(normalizeBasePath('/api/')).toBe('/api'); + expect(normalizeBasePath('api')).toBe('/api'); + }); + + test('matchesNodeGroup recognises global shortcuts', () => { + expect(matchesNodeGroup('', '')).toBe(true); + expect(matchesNodeGroup('global', '')).toBe(true); + expect(matchesNodeGroup('default', '')).toBe(true); + expect(matchesNodeGroup('*', '')).toBe(true); + expect(matchesNodeGroup('*', 'group-a')).toBe(true); + expect(matchesNodeGroup('group-a', 'group-a')).toBe(true); + expect(matchesNodeGroup('group-a', 'group-b')).toBe(false); + }); + + test('joinUrlPaths combines segments without duplicating slashes', () => { + expect(joinUrlPaths('/', '/api', '/messages')).toBe('/api/messages'); + expect(joinUrlPaths('/root', '', '/')).toBe('/root'); + expect(joinUrlPaths('', '', '/tools')).toBe('/tools'); + }); + + test('buildTargetUrl respects hub base path and node prefix', () => { + config.basePath = '/hub'; + const node: ClusterNodeConfig = { + id: 'node-1', + url: 'http://backend:3000', + }; + const target = buildTargetUrl(node, '/hub/mcp/alpha?foo=bar'); + expect(target.toString()).toBe('http://backend:3000/hub/mcp/alpha?foo=bar'); + }); + + test('buildTargetUrl can override base path using node prefix', () => { + config.basePath = '/hub'; + const node: ClusterNodeConfig = { + id: 'node-1', + url: 'http://backend:3000', + pathPrefix: '/', + }; + const target = buildTargetUrl(node, '/hub/mcp/alpha?foo=bar'); + expect(target.toString()).toBe('http://backend:3000/mcp/alpha?foo=bar'); + }); + + test('buildTargetUrl appends to node URL path when provided', () => { + config.basePath = ''; + const node: ClusterNodeConfig = { + id: 'node-1', + url: 'http://backend:3000/root', + }; + const target = buildTargetUrl(node, '/messages?sessionId=123'); + expect(target.toString()).toBe('http://backend:3000/root/messages?sessionId=123'); + }); +});