feat: add tool management features including toggle and description updates (#163)

Co-authored-by: samanhappy@qq.com <my6051199>
This commit is contained in:
samanhappy
2025-06-04 16:03:45 +08:00
committed by GitHub
parent 4039a85ee1
commit 503b60edb7
12 changed files with 518 additions and 57 deletions

View File

@@ -4,12 +4,11 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ServerInfo, ServerConfig } from '../types/index.js';
import { ServerInfo, ServerConfig, ToolInfo } from '../types/index.js';
import { loadSettings, saveSettings, expandEnvVars, replaceEnvVars } from '../config/index.js';
import config from '../config/index.js';
import { getGroup } from './sseService.js';
import { getServersInGroup } from './groupService.js';
import { getSmartRoutingConfig } from '../utils/smartRouting.js';
import { saveToolsAsVectorEmbeddings, searchToolsByVector } from './vectorSearchService.js';
const servers: { [sessionId: string]: Server } = {};
@@ -51,6 +50,21 @@ export const notifyToolChanged = async () => {
});
};
export const syncToolEmbedding = async (serverName: string, toolName: string) => {
const serverInfo = getServerByName(serverName);
if (!serverInfo) {
console.warn(`Server not found: ${serverName}`);
return;
}
const tool = serverInfo.tools.find((t) => t.name === toolName);
if (!tool) {
console.warn(`Tool not found: ${toolName} on server: ${serverName}`);
return;
}
// Save tool as vector embedding for search
saveToolsAsVectorEmbeddings(serverName, [tool]);
};
// Store all server information
let serverInfos: ServerInfo[] = [];
@@ -189,27 +203,15 @@ export const initializeClientsFromSettings = (isInit: boolean): ServerInfo[] =>
}
serverInfo.tools = tools.tools.map((tool) => ({
name: tool.name,
name: name + '/' + tool.name,
description: tool.description || '',
inputSchema: tool.inputSchema || {},
}));
serverInfo.status = 'connected';
serverInfo.error = null;
// Save tools as vector embeddings for search (only when smart routing is enabled)
if (serverInfo.tools.length > 0) {
try {
const smartRoutingConfig = getSmartRoutingConfig();
if (smartRoutingConfig.enabled) {
console.log(
`Smart routing enabled - saving vector embeddings for server ${name}`,
);
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
}
} catch (vectorError) {
console.warn(`Failed to save vector embeddings for server ${name}:`, vectorError);
}
}
// Save tools as vector embeddings for search
saveToolsAsVectorEmbeddings(name, serverInfo.tools);
})
.catch((error) => {
console.error(
@@ -258,11 +260,22 @@ export const getServersInfo = (): Omit<ServerInfo, 'client' | 'transport'>[] =>
const infos = serverInfos.map(({ name, status, tools, createTime, error }) => {
const serverConfig = settings.mcpServers[name];
const enabled = serverConfig ? serverConfig.enabled !== false : true;
// Add enabled status and custom description to each tool
const toolsWithEnabled = tools.map((tool) => {
const toolConfig = serverConfig?.tools?.[tool.name];
return {
...tool,
description: toolConfig?.description || tool.description, // Use custom description if available
enabled: toolConfig?.enabled !== false, // Default to true if not explicitly disabled
};
});
return {
name,
status,
error,
tools,
tools: toolsWithEnabled,
createTime,
enabled,
};
@@ -279,6 +292,23 @@ const getServerByName = (name: string): ServerInfo | undefined => {
return serverInfos.find((serverInfo) => serverInfo.name === name);
};
// Filter tools by server configuration
const filterToolsByConfig = (serverName: string, tools: ToolInfo[]): ToolInfo[] => {
const settings = loadSettings();
const serverConfig = settings.mcpServers[serverName];
if (!serverConfig || !serverConfig.tools) {
// If no tool configuration exists, all tools are enabled by default
return tools;
}
return tools.filter((tool) => {
const toolConfig = serverConfig.tools?.[tool.name];
// If tool is not in config, it's enabled by default
return toolConfig?.enabled !== false;
});
};
// Get server by tool name
const getServerByTool = (toolName: string): ServerInfo | undefined => {
return serverInfos.find((serverInfo) => serverInfo.tools.some((tool) => tool.name === toolName));
@@ -489,7 +519,21 @@ Available servers: ${serversList}`;
const allTools = [];
for (const serverInfo of allServerInfos) {
if (serverInfo.tools && serverInfo.tools.length > 0) {
allTools.push(...serverInfo.tools);
// Filter tools based on server configuration and apply custom descriptions
const enabledTools = filterToolsByConfig(serverInfo.name, serverInfo.tools);
// Apply custom descriptions from configuration
const settings = loadSettings();
const serverConfig = settings.mcpServers[serverInfo.name];
const toolsWithCustomDescriptions = enabledTools.map((tool) => {
const toolConfig = serverConfig?.tools?.[tool.name];
return {
...tool,
description: toolConfig?.description || tool.description, // Use custom description if available
};
});
allTools.push(...toolsWithCustomDescriptions);
}
}
@@ -530,30 +574,54 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
const searchResults = await searchToolsByVector(query, limitNum, thresholdNum, servers);
console.log(`Search results: ${JSON.stringify(searchResults)}`);
// Find actual tool information from serverInfos by serverName and toolName
const tools = searchResults.map((result) => {
// Find the server in serverInfos
const server = serverInfos.find(
(serverInfo) =>
serverInfo.name === result.serverName &&
serverInfo.status === 'connected' &&
serverInfo.enabled !== false,
);
if (server && server.tools && server.tools.length > 0) {
// Find the tool in server.tools
const actualTool = server.tools.find((tool) => tool.name === result.toolName);
if (actualTool) {
// Return the actual tool info from serverInfos
return actualTool;
}
}
const tools = searchResults
.map((result) => {
// Find the server in serverInfos
const server = serverInfos.find(
(serverInfo) =>
serverInfo.name === result.serverName &&
serverInfo.status === 'connected' &&
serverInfo.enabled !== false,
);
if (server && server.tools && server.tools.length > 0) {
// Find the tool in server.tools
const actualTool = server.tools.find((tool) => tool.name === result.toolName);
if (actualTool) {
// Check if the tool is enabled in configuration
const enabledTools = filterToolsByConfig(server.name, [actualTool]);
if (enabledTools.length > 0) {
// Apply custom description from configuration
const settings = loadSettings();
const serverConfig = settings.mcpServers[server.name];
const toolConfig = serverConfig?.tools?.[actualTool.name];
// Fallback to search result if server or tool not found
return {
name: result.toolName,
description: result.description || '',
inputSchema: result.inputSchema || {},
};
});
// Return the actual tool info from serverInfos with custom description
return {
...actualTool,
description: toolConfig?.description || actualTool.description,
};
}
}
}
// Fallback to search result if server or tool not found or disabled
return {
name: result.toolName,
description: result.description || '',
inputSchema: result.inputSchema || {},
};
})
.filter((tool) => {
// Additional filter to remove tools that are disabled
if (tool.name) {
const serverName = searchResults.find((r) => r.toolName === tool.name)?.serverName;
if (serverName) {
const enabledTools = filterToolsByConfig(serverName, [tool as ToolInfo]);
return enabledTools.length > 0;
}
}
return true; // Keep fallback results
});
// Add usage guidance to the response
const response = {
@@ -586,7 +654,7 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
// Special handling for call_tool
if (request.params.name === 'call_tool') {
const { toolName, arguments: toolArgs = {} } = request.params.arguments || {};
let { toolName, arguments: toolArgs = {} } = request.params.arguments || {};
if (!toolName) {
throw new Error('toolName parameter is required');
@@ -631,6 +699,9 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
`Invoking tool '${toolName}' on server '${targetServerInfo.name}' with arguments: ${JSON.stringify(finalArgs)}`,
);
toolName = toolName.startsWith(`${targetServerInfo.name}/`)
? toolName.replace(`${targetServerInfo.name}/`, '')
: toolName;
const result = await client.callTool({
name: toolName,
arguments: finalArgs,
@@ -649,6 +720,10 @@ export const handleCallToolRequest = async (request: any, extra: any) => {
if (!client) {
throw new Error(`Client not found for server: ${request.params.name}`);
}
request.params.name = request.params.name.startsWith(`${serverInfo.name}/`)
? request.params.name.replace(`${serverInfo.name}/`, '')
: request.params.name;
const result = await client.callTool(request.params);
console.log(`Tool call result: ${JSON.stringify(result)}`);
return result;