9.4 Harness中的MCP集成模式
9.4.1 系统级集成的核心挑战
9.4.2 动态工具注册与发现
MCPToolRegistry
import asyncio
import hashlib
import inspect
import json
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class MCPServerConfig:
    """Configuration record for one MCP server known to the registry.

    Positional field order is part of the constructor interface:
    ``server_id, server_name, transport_type, endpoint`` are required, the
    remaining fields have defaults.
    """

    # Unique key under which the registry stores this server.
    server_id: str
    # Human-readable name (used in the registry's log output).
    server_name: str
    # Transport selector: "stdio" | "http".
    transport_type: str
    # Filesystem path (stdio) or URL (http) handed to the client.
    endpoint: str
    # Disabled servers are skipped during tool discovery.
    enabled: bool = True
    # Higher value wins when several servers provide the same tool.
    priority: int = 0
    # Presumably a per-request timeout in seconds -- not consulted by the
    # registry code in this listing; confirm against the client layer.
    timeout_seconds: int = 30
    # Presumably the retry budget per request -- likewise not consulted here.
    max_retries: int = 2
    # Free-form labels for categorising servers; default_factory gives each
    # instance its own list.
    tags: List[str] = field(default_factory=list)
@dataclass
class ToolSchema:
    """A tool description cached by the registry, with cache bookkeeping."""

    # Id of the server the schema was fetched from.
    server_id: str
    # Tool name as reported by the server's "tools/list" response.
    tool_name: str
    # Human-readable description reported by the server.
    description: str
    # The tool's input schema (the "inputSchema" object from the server).
    input_schema: Dict
    # When the entry was cached; compared against the registry's TTL.
    cached_at: datetime
    # Digest of the input schema, used to detect schema changes.
    schema_hash: str
class MCPToolRegistry:
    """Registry of MCP servers and the tools they expose.

    Holds per-server configuration, lazily created client connections, a
    TTL-bound cache of tool schemas, and a per-agent permission table that
    ``call_tool`` consults before dispatching a call.
    """

    def __init__(self):
        """Create an empty registry with a one-hour schema cache TTL."""
        self.servers: Dict[str, "MCPServerConfig"] = {}
        # Keyed by "<server_id>#<tool_name>".
        self.tool_cache: Dict[str, "ToolSchema"] = {}
        self.server_clients: Dict[str, object] = {}
        self.cache_ttl_seconds = 3600  # schemas stay fresh for one hour
        self.lock = asyncio.Lock()
        # agent_id -> server_id -> [tool_names]; "*" allows every tool.
        self.permission_config: Dict[str, Dict[str, List[str]]] = {}

    async def register_server(self, config: "MCPServerConfig") -> None:
        """Register (or replace) a server under ``config.server_id``."""
        async with self.lock:
            self.servers[config.server_id] = config
            print(f"[Registry] Registered MCP Server: {config.server_name}")

    async def unregister_server(self, server_id: str) -> None:
        """Forget a server, drop its cached schemas, and close its client.

        Unknown ids are ignored. Any exception raised by the client's
        ``close()`` still propagates to the caller.
        """
        async with self.lock:
            # Finish all bookkeeping before the fallible close(), so a
            # failing close cannot leave a stale client or stale cache
            # entries behind for a server that is already unregistered.
            self.servers.pop(server_id, None)
            client = self.server_clients.pop(server_id, None)
            self.tool_cache = {
                key: schema
                for key, schema in self.tool_cache.items()
                if schema.server_id != server_id
            }
            if client is not None and hasattr(client, "close"):
                outcome = client.close()
                # Tolerate clients whose close() is synchronous.
                if inspect.isawaitable(outcome):
                    await outcome

    async def discover_tools(self) -> Dict[str, List[str]]:
        """List tool names per enabled server via ``tools/list``.

        Returns:
            Mapping of server id to tool names. A server whose request
            fails is reported with an empty list instead of raising.
        """
        tools_by_server: Dict[str, List[str]] = {}
        # Iterate a snapshot: the awaits below yield to other tasks, and a
        # concurrent register/unregister would otherwise abort the loop with
        # "dictionary changed size during iteration".
        for server_id, config in list(self.servers.items()):
            if not config.enabled:
                continue
            try:
                client = await self._get_client(server_id)
                response = await client.send_request("tools/list")
                tools_by_server[server_id] = [
                    tool["name"]
                    for tool in self._listed_tools(response)
                    if "name" in tool
                ]
            except Exception as e:
                print(f"[Registry] Error discovering tools from {server_id}: {e}")
                tools_by_server[server_id] = []
        return tools_by_server

    async def get_tool_schema(
        self,
        tool_name: str,
        server_id: Optional[str] = None,
    ) -> Optional["ToolSchema"]:
        """Return the schema of ``tool_name``, served from cache when fresh.

        Args:
            tool_name: Name of the tool to look up.
            server_id: Restrict the lookup to this server. When omitted,
                every enabled server is a candidate, highest priority first.

        Returns:
            The cached or freshly fetched schema, or ``None`` when no
            candidate server lists the tool (or every request failed).
        """
        async with self.lock:
            if server_id:
                servers_to_try = [server_id]
            else:
                enabled = [sid for sid, config in self.servers.items() if config.enabled]
                # Stable sort: equal priorities keep registration order.
                servers_to_try = sorted(
                    enabled,
                    key=lambda sid: self.servers[sid].priority,
                    reverse=True,
                )

            # Entries are always stored under "<server_id>#<tool_name>", so
            # the cache must be probed per candidate server; probing the bare
            # tool name can never hit. A fresh entry from any candidate is
            # preferred over a network round-trip.
            ttl = timedelta(seconds=self.cache_ttl_seconds)
            for sid in servers_to_try:
                cached = self.tool_cache.get(self._cache_key(sid, tool_name))
                if cached is not None and datetime.now() - cached.cached_at < ttl:
                    return cached

            # Cache miss: ask each candidate server in turn.
            for sid in servers_to_try:
                try:
                    client = await self._get_client(sid)
                    response = await client.send_request("tools/list")
                    for tool in self._listed_tools(response):
                        if tool.get("name") != tool_name:
                            continue
                        schema = ToolSchema(
                            server_id=sid,
                            tool_name=tool_name,
                            # A tool may omit its description.
                            description=tool.get("description", ""),
                            input_schema=tool["inputSchema"],
                            cached_at=datetime.now(),
                            schema_hash=self._hash_schema(tool),
                        )
                        self.tool_cache[self._cache_key(sid, tool_name)] = schema
                        return schema
                except Exception as e:
                    print(f"[Registry] Error getting schema from {sid}: {e}")
                    continue
            return None

    async def call_tool(
        self,
        tool_name: str,
        arguments: Dict,
        agent_id: Optional[str] = None,
        server_id: Optional[str] = None,
    ) -> Tuple[bool, Any]:
        """Invoke a tool on behalf of an agent.

        Args:
            tool_name: Tool to call.
            arguments: Arguments forwarded in the ``tools/call`` request.
            agent_id: Calling agent; ``None`` (anonymous) is always denied.
            server_id: Target server. When omitted, the highest-priority
                enabled server that lists the tool is used.

        Returns:
            ``(True, result)`` on success, otherwise ``(False, reason)``.
            This method does not raise; exceptions are reported as the
            failure reason.
        """
        try:
            # Decide which server handles the call.
            if not server_id:
                server_id = await self._find_server_for_tool(tool_name)
                if not server_id:
                    return False, f"Tool {tool_name} not found in any server"
            # Permission gate.
            if not await self._check_permission(agent_id, server_id, tool_name):
                return False, f"Agent {agent_id} not authorized to call {tool_name}"
            # Dispatch.
            client = await self._get_client(server_id)
            response = await client.send_request(
                "tools/call",
                {"name": tool_name, "arguments": arguments},
            )
            if "error" in response:
                error = response["error"]
                # Fall back to the whole error payload when it carries no
                # "message", rather than surfacing a bare KeyError text.
                message = error.get("message") if isinstance(error, dict) else None
                return False, message if message is not None else str(error)
            return True, response.get("result")
        except Exception as e:
            return False, str(e)

    async def _get_client(self, server_id: str):
        """Return the cached client for ``server_id``, creating it on first use.

        NOTE(review): creation is not guarded by a lock, so two tasks asking
        for the same new server at once may each create a client.

        Raises:
            ValueError: If the server is not registered or its transport
                type is neither "stdio" nor "http".
        """
        if server_id in self.server_clients:
            return self.server_clients[server_id]
        config = self.servers.get(server_id)
        if not config:
            raise ValueError(f"Server {server_id} not found")
        if config.transport_type == "stdio":
            from mcp_client import StdioMCPClient
            client = StdioMCPClient(config.endpoint)
            started = client.start()
            # If start() is a coroutine function it only runs once awaited.
            # Only bare coroutines are awaited, so a returned background
            # task/future is not blocked on.
            if asyncio.iscoroutine(started):
                await started
        elif config.transport_type == "http":
            from mcp_client import HttpMCPClient
            client = HttpMCPClient(config.endpoint)
            await client.connect()
        else:
            raise ValueError(f"Unknown transport type: {config.transport_type}")
        self.server_clients[server_id] = client
        return client

    async def _find_server_for_tool(self, tool_name: str) -> Optional[str]:
        """Return the highest-priority enabled server listing ``tool_name``.

        Ties go to the server registered first. Returns ``None`` when no
        server lists the tool.
        """
        tools_by_server = await self.discover_tools()
        candidates = [
            sid
            for sid, tools in tools_by_server.items()
            # A server may have been unregistered while discovery awaited.
            if tool_name in tools and sid in self.servers
        ]
        if not candidates:
            return None
        # max() returns the first maximal item, i.e. registration order on ties.
        return max(candidates, key=lambda sid: self.servers[sid].priority)

    async def _check_permission(self, agent_id: Optional[str], server_id: str, tool_name: str) -> bool:
        """Return whether ``agent_id`` may call ``tool_name`` on ``server_id``."""
        if agent_id is None:
            return False  # anonymous calls are denied by default
        allowed_tools = self.permission_config.get(agent_id, {}).get(server_id, [])
        # "*" grants every tool on that server.
        if "*" in allowed_tools or tool_name in allowed_tools:
            return True
        print(f"[Permission Denied] agent={agent_id}, server={server_id}, tool={tool_name}")
        return False

    def _hash_schema(self, tool: Dict) -> str:
        """Return the MD5 hex digest of the tool's key-sorted ``inputSchema`` JSON.

        Used only for change detection, not for security.
        """
        schema_str = json.dumps(tool["inputSchema"], sort_keys=True)
        return hashlib.md5(schema_str.encode()).hexdigest()

    def get_cache_stats(self) -> Dict:
        """Return counts of cached tools, servers and clients, plus cache size.

        ``cache_memory_bytes`` is the total length of the JSON-serialized
        input schemas, not the in-memory footprint.
        """
        return {
            "total_cached_tools": len(self.tool_cache),
            "registered_servers": len(self.servers),
            "active_clients": len(self.server_clients),
            "cache_memory_bytes": sum(len(json.dumps(v.input_schema)) for v in self.tool_cache.values()),
        }

    @staticmethod
    def _cache_key(server_id: str, tool_name: str) -> str:
        """Build the ``tool_cache`` key for a (server, tool) pair."""
        return f"{server_id}#{tool_name}"

    @staticmethod
    def _listed_tools(response: Dict) -> List[Dict]:
        """Extract the tool entries from a ``tools/list`` response."""
        return (response.get("result") or {}).get("tools") or []


# Source-document heading that followed this listing: "Schema缓存策略" (schema caching strategy).
权限与审计网关
错误处理与降级策略
本小节小结
最后更新于
