class StreamingToolExecutor:
"""流式工具执行器:在响应流中处理工具调用"""
def __init__(self, tool_registry: ToolRegistry, max_concurrent: int = 5):
self.tool_registry = tool_registry
self.max_concurrent = max_concurrent
self._pending_executions: Dict[str, asyncio.Task] = {}
async def execute_tools(
self,
tool_uses: List[ToolUseBlock],
app_state: AppState
) -> Dict[str, ToolResultBlock]:
"""
并发执行多个工具调用,返回所有结果
使用 asyncio.Semaphore 限制并发数
"""
semaphore = asyncio.Semaphore(self.max_concurrent)
tasks = []
for tool_use in tool_uses:
task = self._execute_single_tool(tool_use, app_state, semaphore)
tasks.append(task)
self._pending_executions[tool_use.id] = task
# 并发等待所有工具完成
results = await asyncio.gather(*tasks, return_exceptions=True)
# 组织结果
tool_results = {}
for tool_use, result in zip(tool_uses, results):
if isinstance(result, Exception):
tool_results[tool_use.id] = ToolResultBlock(
tool_use_id=tool_use.id,
content=f"Tool execution error: {str(result)}",
is_error=True,
error_type=type(result).__name__
)
else:
tool_results[tool_use.id] = result
return tool_results
async def _execute_single_tool(
self,
tool_use: ToolUseBlock,
app_state: AppState,
semaphore: asyncio.Semaphore
) -> ToolResultBlock:
"""执行单个工具"""
async with semaphore:
try:
# 1. 查找工具
tool = self.tool_registry.get(tool_use.name)
if not tool:
raise ToolNotFoundError(f"Tool '{tool_use.name}' not found")
# 2. 权限检查
if not tool.check_permissions(app_state):
raise PermissionDeniedError(
f"Permission denied for tool '{tool_use.name}'"
)
# 3. 执行工具
result = await tool.call(tool_use.input)
# 4. 组织结果
return ToolResultBlock(
tool_use_id=tool_use.id,
content=str(result),
is_error=False
)
except Exception as e:
return ToolResultBlock(
tool_use_id=tool_use.id,
content=str(e),
is_error=True,
error_type=type(e).__name__
)
def get_tool_progress(self, tool_use_id: str) -> Optional[ToolProgress]:
"""获取某个工具的执行进度(如果可用)"""
task = self._pending_executions.get(tool_use_id)
if not task:
return None
# 实现细节:从工具对象的进度属性中提取(如果工具支持)
# 通常通过工具返回的结果对象或单独的进度查询接口获取
# 这是一个简化的实现,实际需要根据具体工具的API调整
if hasattr(task, '_progress'):
return task._progress
return None