networkx - 可达节点集合
import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
# 构造一个简单的数据对象
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
data = Data(edge_index=edge_index)
# 将数据对象转换为 NetworkX 图形对象
graph = to_networkx(data)
# 计算可达节点数量
reachable_nodes = nx.descendants(graph, 0)
num_reachable_nodes = len(reachable_nodes)
print(num_reachable_nodes)

浙公网安备 33010602011771号