Add prompt management functionality to MCP server (#281)

This commit is contained in:
samanhappy
2025-08-20 14:23:55 +08:00
committed by GitHub
parent 81c3091a5c
commit 6020611f57
15 changed files with 1247 additions and 44 deletions

View File

@@ -1,10 +1,16 @@
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import {
CallToolRequestSchema,
ListToolsRequestSchema,
ListPromptsRequestSchema,
GetPromptRequestSchema,
ServerCapabilities,
} from '@modelcontextprotocol/sdk/types.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ServerInfo, ServerConfig, ToolInfo } from '../types/index.js';
import { ServerInfo, ServerConfig, Tool } from '../types/index.js';
import { loadSettings, saveSettings, expandEnvVars, replaceEnvVars } from '../config/index.js';
import config from '../config/index.js';
import { getGroup } from './sseService.js';
@@ -343,6 +349,7 @@ export const initializeClientsFromSettings = async (
status: 'disconnected',
error: null,
tools: [],
prompts: [],
createTime: Date.now(),
enabled: false,
});
@@ -376,6 +383,7 @@ export const initializeClientsFromSettings = async (
status: 'disconnected',
error: 'Missing OpenAPI specification URL or schema',
tools: [],
prompts: [],
createTime: Date.now(),
});
continue;
@@ -388,6 +396,7 @@ export const initializeClientsFromSettings = async (
status: 'connecting',
error: null,
tools: [],
prompts: [],
createTime: Date.now(),
enabled: conf.enabled === undefined ? true : conf.enabled,
};
@@ -404,7 +413,7 @@ export const initializeClientsFromSettings = async (
// Convert OpenAPI tools to MCP tool format
const openApiTools = openApiClient.getTools();
const mcpTools: ToolInfo[] = openApiTools.map((tool) => ({
const mcpTools: Tool[] = openApiTools.map((tool) => ({
name: `${name}-${tool.name}`,
description: tool.description,
inputSchema: cleanInputSchema(tool.inputSchema),
@@ -469,6 +478,7 @@ export const initializeClientsFromSettings = async (
status: 'connecting',
error: null,
tools: [],
prompts: [],
client,
transport,
options: requestOptions,
@@ -480,32 +490,63 @@ export const initializeClientsFromSettings = async (
.connect(transport, initRequestOptions || requestOptions)
.then(() => {
console.log(`Successfully connected client for server: ${name}`);
client
.listTools({}, initRequestOptions || requestOptions)
.then((tools) => {
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
serverInfo.tools = tools.tools.map((tool) => ({
name: `${name}-${tool.name}`,
description: tool.description || '',
inputSchema: cleanInputSchema(tool.inputSchema || {}),
}));
serverInfo.status = 'connected';
serverInfo.error = null;
let dataError: Error | null = null;
if (capabilities?.tools) {
client
.listTools({}, initRequestOptions || requestOptions)
.then((tools) => {
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
serverInfo.tools = tools.tools.map((tool) => ({
name: `${name}-${tool.name}`,
description: tool.description || '',
inputSchema: cleanInputSchema(tool.inputSchema || {}),
}));
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
})
.catch((error) => {
console.error(
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
// Set up keep-alive ping for SSE connections
setupKeepAlive(serverInfo, conf);
if (capabilities?.prompts) {
client
.listPrompts({}, initRequestOptions || requestOptions)
.then((prompts) => {
console.log(
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
);
serverInfo.prompts = prompts.prompts.map((prompt) => ({
name: `${name}-${prompt.name}`,
title: prompt.title,
description: prompt.description,
arguments: prompt.arguments,
}));
})
.catch((error) => {
console.error(
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
);
dataError = error;
});
}
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
})
.catch((error) => {
console.error(
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
);
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to list tools: ${error.stack} `;
});
if (!dataError) {
serverInfo.status = 'connected';
serverInfo.error = null;
// Set up keep-alive ping for SSE connections
setupKeepAlive(serverInfo, conf);
} else {
serverInfo.status = 'disconnected';
serverInfo.error = `Failed to list data: ${dataError} `;
}
})
.catch((error) => {
console.error(
@@ -532,7 +573,7 @@ export const getServersInfo = (): Omit<ServerInfo, 'client' | 'transport'>[] =>
const filterServerInfos: ServerInfo[] = dataService.filterData
? dataService.filterData(serverInfos)
: serverInfos;
const infos = filterServerInfos.map(({ name, status, tools, createTime, error }) => {
const infos = filterServerInfos.map(({ name, status, tools, prompts, createTime, error }) => {
const serverConfig = settings.mcpServers[name];
const enabled = serverConfig ? serverConfig.enabled !== false : true;
@@ -546,11 +587,21 @@ export const getServersInfo = (): Omit<ServerInfo, 'client' | 'transport'>[] =>
};
});
const promptsWithEnabled = prompts.map((prompt) => {
const promptConfig = serverConfig?.prompts?.[prompt.name];
return {
...prompt,
description: promptConfig?.description || prompt.description, // Use custom description if available
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
};
});
return {
name,
status,
error,
tools: toolsWithEnabled,
prompts: promptsWithEnabled,
createTime,
enabled,
};
@@ -568,7 +619,7 @@ const getServerByName = (name: string): ServerInfo | undefined => {
};
// Filter tools by server configuration
const filterToolsByConfig = (serverName: string, tools: ToolInfo[]): ToolInfo[] => {
const filterToolsByConfig = (serverName: string, tools: Tool[]): Tool[] => {
const settings = loadSettings();
const serverConfig = settings.mcpServers[serverName];
@@ -948,7 +999,7 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
if (tool.name) {
const serverName = searchResults.find((r) => r.toolName === tool.name)?.serverName;
if (serverName) {
const enabledTools = filterToolsByConfig(serverName, [tool as ToolInfo]);
const enabledTools = filterToolsByConfig(serverName, [tool as Tool]);
return enabledTools.length > 0;
}
}
@@ -1139,6 +1190,119 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
}
};
export const handleGetPromptRequest = async (request: any, extra: any) => {
try {
const { name, arguments: promptArgs } = request.params;
let server: ServerInfo | undefined;
if (extra && extra.server) {
server = getServerByName(extra.server);
} else {
// Find the first server that has this tool
server = serverInfos.find(
(serverInfo) =>
serverInfo.status === 'connected' &&
serverInfo.enabled !== false &&
serverInfo.prompts.find((prompt) => prompt.name === name),
);
}
if (!server) {
throw new Error(`Server not found: ${name}`);
}
// Remove server prefix from prompt name if present
const cleanPromptName = name.startsWith(`${server.name}-`)
? name.replace(`${server.name}-`, '')
: name;
const promptParams = {
name: cleanPromptName || '',
arguments: promptArgs,
};
// Log the final promptParams
console.log(`Calling getPrompt with params: ${JSON.stringify(promptParams)}`);
const prompt = await server.client?.getPrompt(promptParams);
console.log(`Received prompt: ${JSON.stringify(prompt)}`);
if (!prompt) {
throw new Error(`Prompt not found: ${cleanPromptName}`);
}
return prompt;
} catch (error) {
console.error(`Error handling GetPromptRequest: ${error}`);
return {
content: [
{
type: 'text',
text: `Error: ${error}`,
},
],
isError: true,
};
}
};
export const handleListPromptsRequest = async (_: any, extra: any) => {
const sessionId = extra.sessionId || '';
const group = getGroup(sessionId);
console.log(`Handling ListPromptsRequest for group: ${group}`);
const allServerInfos = getDataService()
.filterData(serverInfos)
.filter((serverInfo) => {
if (serverInfo.enabled === false) return false;
if (!group) return true;
const serversInGroup = getServersInGroup(group);
if (!serversInGroup || serversInGroup.length === 0) return serverInfo.name === group;
return serversInGroup.includes(serverInfo.name);
});
const allPrompts: any[] = [];
for (const serverInfo of allServerInfos) {
if (serverInfo.prompts && serverInfo.prompts.length > 0) {
// Filter prompts based on server configuration
const settings = loadSettings();
const serverConfig = settings.mcpServers[serverInfo.name];
let enabledPrompts = serverInfo.prompts;
if (serverConfig && serverConfig.prompts) {
enabledPrompts = serverInfo.prompts.filter((prompt: any) => {
const promptConfig = serverConfig.prompts?.[prompt.name];
// If prompt is not in config, it's enabled by default
return promptConfig?.enabled !== false;
});
}
// If this is a group request, apply group-level prompt filtering
if (group) {
const serverConfigInGroup = getServerConfigInGroup(group, serverInfo.name);
if (
serverConfigInGroup &&
serverConfigInGroup.tools !== 'all' &&
Array.isArray(serverConfigInGroup.tools)
) {
// Note: Group config uses 'tools' field but we're filtering prompts here
// This might be a design decision to control access at the server level
}
}
// Apply custom descriptions from server configuration
const promptsWithCustomDescriptions = enabledPrompts.map((prompt: any) => {
const promptConfig = serverConfig?.prompts?.[prompt.name];
return {
...prompt,
description: promptConfig?.description || prompt.description, // Use custom description if available
};
});
allPrompts.push(...promptsWithCustomDescriptions);
}
}
return {
prompts: allPrompts,
};
};
// Create McpServer instance
export const createMcpServer = (name: string, version: string, group?: string): Server => {
// Determine server name based on routing type
@@ -1157,8 +1321,13 @@ export const createMcpServer = (name: string, version: string, group?: string):
}
// If no group, use default name (global routing)
const server = new Server({ name: serverName, version }, { capabilities: { tools: {} } });
const server = new Server(
{ name: serverName, version },
{ capabilities: { tools: {}, prompts: {}, resources: {} } },
);
server.setRequestHandler(ListToolsRequestSchema, handleListToolsRequest);
server.setRequestHandler(CallToolRequestSchema, handleCallToolRequest);
server.setRequestHandler(GetPromptRequestSchema, handleGetPromptRequest);
server.setRequestHandler(ListPromptsRequestSchema, handleListPromptsRequest);
return server;
};