from typing import Dict, List, Set, Optional
from dataclasses import dataclass, field
from enum import Enum
import json
class TaskType(Enum):
    """The seven supported kinds of task.

    Each member's value is the lowercase string identifier of that kind;
    ``TaskDefinition.to_dict`` emits this value for serialization.
    """

    # Identifiers for locally hosted tasks.
    LOCAL_BASH = "local_bash"
    LOCAL_AGENT = "local_agent"

    # Identifiers for agent/teammate style tasks.
    REMOTE_AGENT = "remote_agent"
    IN_PROCESS_TEAMMATE = "in_process_teammate"

    # Remaining task kinds.
    WORKFLOW = "workflow"
    MONITOR_MCP = "monitor_mcp"
    DREAM = "dream"
class TaskState(Enum):
    """Lifecycle states of a task.

    Each member's value is the lowercase string form of its name.
    """

    # States before and during execution.
    PENDING = "pending"
    RUNNING = "running"

    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    KILLED = "killed"
@dataclass
class TaskDefinition:
    """Declarative description of a single task.

    The class only stores its fields; it does not interpret
    ``timeout_seconds``, ``max_retries`` or ``execution_config`` itself.
    """

    # Unique identifier; TaskDAG uses it as the lookup key.
    task_id: str
    # Kind of task (a TaskType member).
    task_type: TaskType
    # Free-form human-readable description.
    description: str
    # Ids of the tasks this task depends on.
    dependencies: List[str] = field(default_factory=list)
    # Defaults to 300.
    timeout_seconds: int = 300
    # Defaults to 1.
    max_retries: int = 1
    # Arbitrary extra configuration; empty dict by default.
    execution_config: Dict = field(default_factory=dict)

    def add_dependency(self, task_id: str) -> None:
        """Record ``task_id`` as a dependency unless it is already listed."""
        if task_id in self.dependencies:
            return
        self.dependencies.append(task_id)

    def to_dict(self) -> dict:
        """Return the fields as a plain dict, with ``task_type`` as its value.

        The ``dependencies`` and ``execution_config`` entries are the same
        objects held by this instance, not copies.
        """
        return dict(
            task_id=self.task_id,
            task_type=self.task_type.value,
            description=self.description,
            dependencies=self.dependencies,
            timeout_seconds=self.timeout_seconds,
            max_retries=self.max_retries,
            execution_config=self.execution_config,
        )
class TaskDAG:
    """Dependency graph of task definitions, keyed by task id.

    An edge runs from a prerequisite to the task that lists it in its
    ``dependencies``. The graph can be validated, topologically sorted,
    split into groups of tasks that may run in parallel, and rendered as
    Mermaid source.
    """

    def __init__(self):
        # task_id -> task definition, in insertion order.
        self.tasks: Dict[str, "TaskDefinition"] = {}
        # Result of the most recent topological_sort() call.
        self.execution_order: List[str] = []

    def add_task(self, task_def: "TaskDefinition") -> None:
        """Register a task definition.

        A definition whose ``task_id`` is already registered replaces the
        earlier one.
        """
        self.tasks[task_def.task_id] = task_def

    def validate(self) -> bool:
        """Check that the graph is a well-formed DAG.

        Returns:
            True when every dependency refers to a registered task and the
            graph contains no cycle.

        Raises:
            ValueError: if a dependency is not registered, or if a cycle
                exists.
        """
        # 1. Every referenced dependency must be a registered task.
        for task_id, task in self.tasks.items():
            for dep_id in task.dependencies:
                if dep_id not in self.tasks:
                    raise ValueError(f"任务{task_id}依赖的任务{dep_id}不存在")
        # 2. The dependency relation must be acyclic.
        if self._has_cycle():
            raise ValueError("检测到循环依赖")
        return True

    def _has_cycle(self) -> bool:
        """Return True if the dependency graph contains a cycle.

        Uses an iterative depth-first search with an explicit stack so that
        long dependency chains cannot exceed Python's recursion limit.
        Every dependency must be a registered task (``validate`` checks
        this first); otherwise a KeyError is raised.
        """
        visited: Set[str] = set()
        # Nodes on the DFS path currently being explored.
        on_path: Set[str] = set()

        for root in self.tasks:
            if root in visited:
                continue
            visited.add(root)
            on_path.add(root)
            # Each frame holds a node and an iterator over the dependencies
            # of that node which have not been examined yet.
            stack = [(root, iter(self.tasks[root].dependencies))]
            while stack:
                node, pending = stack[-1]
                for dep_id in pending:
                    if dep_id in on_path:
                        # Back edge to a node on the current path.
                        return True
                    if dep_id not in visited:
                        visited.add(dep_id)
                        on_path.add(dep_id)
                        stack.append(
                            (dep_id, iter(self.tasks[dep_id].dependencies))
                        )
                        break
                else:
                    # All dependencies of ``node`` explored: leave the path.
                    on_path.discard(node)
                    stack.pop()
        return False

    def topological_sort(self) -> List[str]:
        """Compute an execution order in which prerequisites come first.

        Runs Kahn's algorithm. Ties are resolved by insertion order, so the
        result is deterministic. The order is also stored in
        ``self.execution_order``.

        Returns:
            All task ids, each appearing after all of its dependencies.

        Raises:
            ValueError: if ``validate`` rejects the graph.
        """
        self.validate()

        in_degree: Dict[str, int] = {}
        # prerequisite id -> ids of the tasks that depend on it.
        dependents: Dict[str, List[str]] = {
            task_id: [] for task_id in self.tasks
        }
        for task_id, task in self.tasks.items():
            # De-duplicate: a dependency listed twice is still one edge.
            # Counting it twice would leave the in-degree permanently above
            # zero and silently drop the task from the result.
            unique_deps = dict.fromkeys(task.dependencies)
            in_degree[task_id] = len(unique_deps)
            for dep_id in unique_deps:
                dependents[dep_id].append(task_id)

        # ``order`` doubles as the FIFO queue: ``head`` marks the next node
        # to process, which avoids the O(n) cost of ``list.pop(0)``.
        order = [task_id for task_id in self.tasks if in_degree[task_id] == 0]
        head = 0
        while head < len(order):
            node = order[head]
            head += 1
            for dependent_id in dependents[node]:
                in_degree[dependent_id] -= 1
                if in_degree[dependent_id] == 0:
                    order.append(dependent_id)

        self.execution_order = order
        return order

    def get_parallelizable_groups(self) -> List[List[str]]:
        """Split the tasks into ordered groups that may run in parallel.

        Group ``k`` holds the tasks whose longest dependency chain has
        length ``k``. All dependencies of a task therefore lie in earlier
        groups, and no two tasks in the same group depend on each other.
        Every task appears in exactly one group.

        Returns:
            The groups in execution order; an empty list for an empty graph.

        Raises:
            ValueError: if ``validate`` rejects the graph.
        """
        # topological_sort() validates and guarantees that a task's
        # dependencies are visited before the task itself.
        order = self.topological_sort()

        depth_of: Dict[str, int] = {}
        groups: List[List[str]] = []
        for task_id in order:
            deps = self.tasks[task_id].dependencies
            # Tasks without dependencies get depth 0.
            depth = 1 + max((depth_of[dep_id] for dep_id in deps), default=-1)
            depth_of[task_id] = depth
            if depth == len(groups):
                groups.append([])
            groups[depth].append(task_id)
        return groups

    def visualize(self) -> str:
        """Return Mermaid ``graph TD`` source describing the graph.

        One node line is emitted per task, labelled with its id and task
        type value, followed by one edge line per dependency pointing from
        the prerequisite to the dependent task. ``#`` in ids is replaced
        with ``_`` in node identifiers.
        """
        lines = ["graph TD"]
        # Node declarations.
        for task_id, task in self.tasks.items():
            label = f"{task.task_id}<br/>({task.task_type.value})"
            lines.append(f' {task_id.replace("#", "_")}["{label}"]')
        # Edges: prerequisite --> dependent.
        for task_id, task in self.tasks.items():
            for dep_id in task.dependencies:
                lines.append(
                    f" {dep_id.replace('#', '_')} --> {task_id.replace('#', '_')}"
                )
        return "\n".join(lines)
# Usage example
if __name__ == "__main__":
    dag = TaskDAG()

    # Each row: (task_id, task_type, description, dependencies).
    example_tasks = [
        ("task_0", TaskType.LOCAL_BASH, "获取论文数据", ()),
        ("task_1", TaskType.LOCAL_AGENT, "内容分析", ("task_0",)),
        ("task_2", TaskType.LOCAL_AGENT, "引用分析", ("task_0",)),
        ("task_3", TaskType.LOCAL_AGENT, "综合评估", ("task_1", "task_2")),
    ]

    # Register the tasks.
    for example_id, example_type, example_desc, example_deps in example_tasks:
        dag.add_task(
            TaskDefinition(
                task_id=example_id,
                task_type=example_type,
                description=example_desc,
                dependencies=list(example_deps),
            )
        )

    # Validate and sort.
    order = dag.topological_sort()
    print("执行顺序:", order)

    # Compute the parallel groups.
    groups = dag.get_parallelizable_groups()
    print("并行组:", groups)

    # Render the Mermaid diagram source.
    print(dag.visualize())