mirror of
https://github.com/samanhappy/mcphub.git
synced 2025-12-31 20:00:00 -05:00
Add prompt management functionality to MCP server (#281)
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
ListToolsRequestSchema,
|
||||
ListPromptsRequestSchema,
|
||||
GetPromptRequestSchema,
|
||||
ServerCapabilities,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { ServerInfo, ServerConfig, ToolInfo } from '../types/index.js';
|
||||
import { ServerInfo, ServerConfig, Tool } from '../types/index.js';
|
||||
import { loadSettings, saveSettings, expandEnvVars, replaceEnvVars } from '../config/index.js';
|
||||
import config from '../config/index.js';
|
||||
import { getGroup } from './sseService.js';
|
||||
@@ -343,6 +349,7 @@ export const initializeClientsFromSettings = async (
|
||||
status: 'disconnected',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
enabled: false,
|
||||
});
|
||||
@@ -376,6 +383,7 @@ export const initializeClientsFromSettings = async (
|
||||
status: 'disconnected',
|
||||
error: 'Missing OpenAPI specification URL or schema',
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
});
|
||||
continue;
|
||||
@@ -388,6 +396,7 @@ export const initializeClientsFromSettings = async (
|
||||
status: 'connecting',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
createTime: Date.now(),
|
||||
enabled: conf.enabled === undefined ? true : conf.enabled,
|
||||
};
|
||||
@@ -404,7 +413,7 @@ export const initializeClientsFromSettings = async (
|
||||
|
||||
// Convert OpenAPI tools to MCP tool format
|
||||
const openApiTools = openApiClient.getTools();
|
||||
const mcpTools: ToolInfo[] = openApiTools.map((tool) => ({
|
||||
const mcpTools: Tool[] = openApiTools.map((tool) => ({
|
||||
name: `${name}-${tool.name}`,
|
||||
description: tool.description,
|
||||
inputSchema: cleanInputSchema(tool.inputSchema),
|
||||
@@ -469,6 +478,7 @@ export const initializeClientsFromSettings = async (
|
||||
status: 'connecting',
|
||||
error: null,
|
||||
tools: [],
|
||||
prompts: [],
|
||||
client,
|
||||
transport,
|
||||
options: requestOptions,
|
||||
@@ -480,32 +490,63 @@ export const initializeClientsFromSettings = async (
|
||||
.connect(transport, initRequestOptions || requestOptions)
|
||||
.then(() => {
|
||||
console.log(`Successfully connected client for server: ${name}`);
|
||||
client
|
||||
.listTools({}, initRequestOptions || requestOptions)
|
||||
.then((tools) => {
|
||||
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
|
||||
const capabilities: ServerCapabilities | undefined = client.getServerCapabilities();
|
||||
console.log(`Server capabilities: ${JSON.stringify(capabilities)}`);
|
||||
|
||||
serverInfo.tools = tools.tools.map((tool) => ({
|
||||
name: `${name}-${tool.name}`,
|
||||
description: tool.description || '',
|
||||
inputSchema: cleanInputSchema(tool.inputSchema || {}),
|
||||
}));
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.error = null;
|
||||
let dataError: Error | null = null;
|
||||
if (capabilities?.tools) {
|
||||
client
|
||||
.listTools({}, initRequestOptions || requestOptions)
|
||||
.then((tools) => {
|
||||
console.log(`Successfully listed ${tools.tools.length} tools for server: ${name}`);
|
||||
serverInfo.tools = tools.tools.map((tool) => ({
|
||||
name: `${name}-${tool.name}`,
|
||||
description: tool.description || '',
|
||||
inputSchema: cleanInputSchema(tool.inputSchema || {}),
|
||||
}));
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
// Set up keep-alive ping for SSE connections
|
||||
setupKeepAlive(serverInfo, conf);
|
||||
if (capabilities?.prompts) {
|
||||
client
|
||||
.listPrompts({}, initRequestOptions || requestOptions)
|
||||
.then((prompts) => {
|
||||
console.log(
|
||||
`Successfully listed ${prompts.prompts.length} prompts for server: ${name}`,
|
||||
);
|
||||
serverInfo.prompts = prompts.prompts.map((prompt) => ({
|
||||
name: `${name}-${prompt.name}`,
|
||||
title: prompt.title,
|
||||
description: prompt.description,
|
||||
arguments: prompt.arguments,
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list prompts for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
dataError = error;
|
||||
});
|
||||
}
|
||||
|
||||
// Save tools as vector embeddings for search
|
||||
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
`Failed to list tools for server ${name} by error: ${error} with stack: ${error.stack}`,
|
||||
);
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to list tools: ${error.stack} `;
|
||||
});
|
||||
if (!dataError) {
|
||||
serverInfo.status = 'connected';
|
||||
serverInfo.error = null;
|
||||
|
||||
// Set up keep-alive ping for SSE connections
|
||||
setupKeepAlive(serverInfo, conf);
|
||||
} else {
|
||||
serverInfo.status = 'disconnected';
|
||||
serverInfo.error = `Failed to list data: ${dataError} `;
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
@@ -532,7 +573,7 @@ export const getServersInfo = (): Omit<ServerInfo, 'client' | 'transport'>[] =>
|
||||
const filterServerInfos: ServerInfo[] = dataService.filterData
|
||||
? dataService.filterData(serverInfos)
|
||||
: serverInfos;
|
||||
const infos = filterServerInfos.map(({ name, status, tools, createTime, error }) => {
|
||||
const infos = filterServerInfos.map(({ name, status, tools, prompts, createTime, error }) => {
|
||||
const serverConfig = settings.mcpServers[name];
|
||||
const enabled = serverConfig ? serverConfig.enabled !== false : true;
|
||||
|
||||
@@ -546,11 +587,21 @@ export const getServersInfo = (): Omit<ServerInfo, 'client' | 'transport'>[] =>
|
||||
};
|
||||
});
|
||||
|
||||
const promptsWithEnabled = prompts.map((prompt) => {
|
||||
const promptConfig = serverConfig?.prompts?.[prompt.name];
|
||||
return {
|
||||
...prompt,
|
||||
description: promptConfig?.description || prompt.description, // Use custom description if available
|
||||
enabled: promptConfig?.enabled !== false, // Default to true if not explicitly disabled
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
name,
|
||||
status,
|
||||
error,
|
||||
tools: toolsWithEnabled,
|
||||
prompts: promptsWithEnabled,
|
||||
createTime,
|
||||
enabled,
|
||||
};
|
||||
@@ -568,7 +619,7 @@ const getServerByName = (name: string): ServerInfo | undefined => {
|
||||
};
|
||||
|
||||
// Filter tools by server configuration
|
||||
const filterToolsByConfig = (serverName: string, tools: ToolInfo[]): ToolInfo[] => {
|
||||
const filterToolsByConfig = (serverName: string, tools: Tool[]): Tool[] => {
|
||||
const settings = loadSettings();
|
||||
const serverConfig = settings.mcpServers[serverName];
|
||||
|
||||
@@ -948,7 +999,7 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
|
||||
if (tool.name) {
|
||||
const serverName = searchResults.find((r) => r.toolName === tool.name)?.serverName;
|
||||
if (serverName) {
|
||||
const enabledTools = filterToolsByConfig(serverName, [tool as ToolInfo]);
|
||||
const enabledTools = filterToolsByConfig(serverName, [tool as Tool]);
|
||||
return enabledTools.length > 0;
|
||||
}
|
||||
}
|
||||
@@ -1139,6 +1190,119 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const handleGetPromptRequest = async (request: any, extra: any) => {
|
||||
try {
|
||||
const { name, arguments: promptArgs } = request.params;
|
||||
let server: ServerInfo | undefined;
|
||||
if (extra && extra.server) {
|
||||
server = getServerByName(extra.server);
|
||||
} else {
|
||||
// Find the first server that has this tool
|
||||
server = serverInfos.find(
|
||||
(serverInfo) =>
|
||||
serverInfo.status === 'connected' &&
|
||||
serverInfo.enabled !== false &&
|
||||
serverInfo.prompts.find((prompt) => prompt.name === name),
|
||||
);
|
||||
}
|
||||
if (!server) {
|
||||
throw new Error(`Server not found: ${name}`);
|
||||
}
|
||||
|
||||
// Remove server prefix from prompt name if present
|
||||
const cleanPromptName = name.startsWith(`${server.name}-`)
|
||||
? name.replace(`${server.name}-`, '')
|
||||
: name;
|
||||
|
||||
const promptParams = {
|
||||
name: cleanPromptName || '',
|
||||
arguments: promptArgs,
|
||||
};
|
||||
// Log the final promptParams
|
||||
console.log(`Calling getPrompt with params: ${JSON.stringify(promptParams)}`);
|
||||
const prompt = await server.client?.getPrompt(promptParams);
|
||||
console.log(`Received prompt: ${JSON.stringify(prompt)}`);
|
||||
if (!prompt) {
|
||||
throw new Error(`Prompt not found: ${cleanPromptName}`);
|
||||
}
|
||||
|
||||
return prompt;
|
||||
} catch (error) {
|
||||
console.error(`Error handling GetPromptRequest: ${error}`);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error: ${error}`,
|
||||
},
|
||||
],
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const handleListPromptsRequest = async (_: any, extra: any) => {
|
||||
const sessionId = extra.sessionId || '';
|
||||
const group = getGroup(sessionId);
|
||||
console.log(`Handling ListPromptsRequest for group: ${group}`);
|
||||
|
||||
const allServerInfos = getDataService()
|
||||
.filterData(serverInfos)
|
||||
.filter((serverInfo) => {
|
||||
if (serverInfo.enabled === false) return false;
|
||||
if (!group) return true;
|
||||
const serversInGroup = getServersInGroup(group);
|
||||
if (!serversInGroup || serversInGroup.length === 0) return serverInfo.name === group;
|
||||
return serversInGroup.includes(serverInfo.name);
|
||||
});
|
||||
|
||||
const allPrompts: any[] = [];
|
||||
for (const serverInfo of allServerInfos) {
|
||||
if (serverInfo.prompts && serverInfo.prompts.length > 0) {
|
||||
// Filter prompts based on server configuration
|
||||
const settings = loadSettings();
|
||||
const serverConfig = settings.mcpServers[serverInfo.name];
|
||||
|
||||
let enabledPrompts = serverInfo.prompts;
|
||||
if (serverConfig && serverConfig.prompts) {
|
||||
enabledPrompts = serverInfo.prompts.filter((prompt: any) => {
|
||||
const promptConfig = serverConfig.prompts?.[prompt.name];
|
||||
// If prompt is not in config, it's enabled by default
|
||||
return promptConfig?.enabled !== false;
|
||||
});
|
||||
}
|
||||
|
||||
// If this is a group request, apply group-level prompt filtering
|
||||
if (group) {
|
||||
const serverConfigInGroup = getServerConfigInGroup(group, serverInfo.name);
|
||||
if (
|
||||
serverConfigInGroup &&
|
||||
serverConfigInGroup.tools !== 'all' &&
|
||||
Array.isArray(serverConfigInGroup.tools)
|
||||
) {
|
||||
// Note: Group config uses 'tools' field but we're filtering prompts here
|
||||
// This might be a design decision to control access at the server level
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom descriptions from server configuration
|
||||
const promptsWithCustomDescriptions = enabledPrompts.map((prompt: any) => {
|
||||
const promptConfig = serverConfig?.prompts?.[prompt.name];
|
||||
return {
|
||||
...prompt,
|
||||
description: promptConfig?.description || prompt.description, // Use custom description if available
|
||||
};
|
||||
});
|
||||
|
||||
allPrompts.push(...promptsWithCustomDescriptions);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
prompts: allPrompts,
|
||||
};
|
||||
};
|
||||
|
||||
// Create McpServer instance
|
||||
export const createMcpServer = (name: string, version: string, group?: string): Server => {
|
||||
// Determine server name based on routing type
|
||||
@@ -1157,8 +1321,13 @@ export const createMcpServer = (name: string, version: string, group?: string):
|
||||
}
|
||||
// If no group, use default name (global routing)
|
||||
|
||||
const server = new Server({ name: serverName, version }, { capabilities: { tools: {} } });
|
||||
const server = new Server(
|
||||
{ name: serverName, version },
|
||||
{ capabilities: { tools: {}, prompts: {}, resources: {} } },
|
||||
);
|
||||
server.setRequestHandler(ListToolsRequestSchema, handleListToolsRequest);
|
||||
server.setRequestHandler(CallToolRequestSchema, handleCallToolRequest);
|
||||
server.setRequestHandler(GetPromptRequestSchema, handleGetPromptRequest);
|
||||
server.setRequestHandler(ListPromptsRequestSchema, handleListPromptsRequest);
|
||||
return server;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user