Graph PyTorch Geometric 库
PyTorch Geometric (PyG)
1. 安装
安装:site
# Notebook cell: install PyTorch Geometric and its companion packages.
import os
import torch
print(torch.__version__)
# Export the installed torch version so the shell commands below can read it
# as ${TORCH} when building the wheel-index URL.
os.environ['TORCH'] = torch.__version__
# Lines starting with `!` are notebook shell escapes, not Python statements.
! pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
! pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
! pip install torch-geometric
# Alternative (disabled): install torch-geometric from the GitHub repository.
# ! pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
! pip install ogb
2. torch_geometric.datasets 和 torch_geometric.data 模块
2.1 加载数据集
实例代码:
# Load the Cora dataset through the Planetoid dataset class.
from torch_geometric.datasets import Planetoid
root = './tmp/cora' # directory where the dataset is stored
name = 'Cora' # dataset name
dataset = Planetoid(root=root, name=name)
# Print the dataset summary, the number of graphs it holds, and its first graph.
print(dataset)
print(len(dataset))
print(dataset[0])
# Load the ENZYMES dataset through the TUDataset class.
# Note: this rebinds `root`, `name` and `dataset` from the previous snippet.
from torch_geometric.datasets import TUDataset
root = './tmp/ENZYMES'
name = 'ENZYMES'
dataset= TUDataset(root=root, name=name)
# Print the dataset summary, the number of graphs it holds, and its first graph.
print(dataset)
print(len(dataset))
print(dataset[0])
- 每个 dataset 对象是一系列 graph 的集合
- 可以通过索引 dataset[0] 获取一个 graph
  - 返回 torch_geometric.data.Data 类型
2.2 torch_geometric.datasets 模块
torch_geometric.datasets 中数据集类的常用属性
# Print summary attributes of the `dataset` object loaded earlier.
# NOTE(review): the *_labels / *_attributes properties look specific to
# TUDataset — confirm before running this against another dataset class.
# graph related
print("No. of graphs: ", len(dataset))
print("No. of graph classes: ", dataset.num_classes)
# node related
print("No. of node labels: ", dataset.num_node_labels)
# num_features and num_node_features are printed under the same label;
# presumably one is an alias of the other — TODO confirm.
print("No. of node features: ", dataset.num_features)
print("No. of node features: ", dataset.num_node_features)
print("No. of node attributes: ", dataset.num_node_attributes)
# edge related
print("No. of edge labels: ", dataset.num_edge_labels)
print("No. of edge features: ", dataset.num_edge_features)
print("No. of edge attributes: ", dataset.num_edge_attributes)
2.3 torch_geometric.data 模块
2.3.1 torch_geometric.data.Data 类常用属性
# Inspect a single graph taken from the dataset.
data = dataset[0]
# graph related
print(data)
print("Graph is directed:", data.is_directed())
print("Graph is undirected:", data.is_undirected())
print("Contains isolated nodes:", data.has_isolated_nodes())
print("Contains self-loops:", data.has_self_loops())
# node related
print("No. of nodes: ", data.num_nodes)
print("No. of node features: ", data.num_features)
print("No. of node features: ", data.num_node_features)
# edge related
print("No. of edges: ", data.num_edges)
print("No. of edge features: ", data.num_edge_features)
# Only report the training split when the graph actually carries a
# `train_mask`. The dataset most recently loaded in this file is ENZYMES
# (TUDataset); presumably its graphs have no node-level split mask, in which
# case an unconditional access would raise AttributeError.
if hasattr(data, "train_mask"):
    print("No. of training nodes: ", data.train_mask.sum().item())
# number of distinct label values stored in data.y
print(data.y.unique().size()[0])
2.4 自定义数据集
2.4.1 自定义 data 类
参考资料:Data Handling of Graphs, PyG Tutorial
实例代码: 创建一个 Data 类型的数据
# Build a small graph by hand as a torch_geometric.data.Data object.
import torch
from torch_geometric.data import Data
# edge index : (2, E)
# two undirected edges: 0-1, 1-2
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# equivalent to:
# edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)
# edge_index = edge_index.t().contiguous()
# node feature : (N, D)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# bundle the node features and the edge index into one Data object
data_1 = Data(x=x, edge_index=edge_index)
实例代码: undirected graph 和 directed graph
# Build a second graph that reuses the node features `x` defined earlier but
# lists edge 1->2 without its reverse, then compare directedness with data_1.
# three directed edges: 0->1, 1->0, 1->2
edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.long)
data_2 = Data(x=x, edge_index=edge_index)
print("Graph 1 is directed: ", data_1.is_directed()) # False
print("Graph 2 is directed: ", data_2.is_directed()) # True
- edge_index: 为 edge 信息
  - Size 为 \((2, |\mathcal{E}|)\) 的 tensor,数据类型为 torch.long
  - 对于 undirected graph,同一个 edge 的正向和反向都应包括在内
2.4.2 自定义 datasets 类
- 方便保存和加载数据
- 将一系列的 graph 打包
- TUDataset 类源码
(1) 保存 Dataset
from torch_geometric.data import InMemoryDataset


class MyDataset(InMemoryDataset):
    """In-memory dataset that saves a list of graphs to disk and loads it back.

    Args:
        root: Directory passed to the base class; the processed file
            ``data.pt`` is resolved through ``self.processed_paths``.
        data_list: Graphs to store (``torch_geometric.data.Data`` objects).
    """

    def __init__(self, root, data_list):
        # Set before calling the base constructor because process() reads it;
        # the base constructor is expected to trigger process() when the
        # processed file is missing — TODO confirm against the installed
        # PyG version.
        self.data_list = data_list
        super().__init__(root)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject pickled graph objects —
        # confirm with the torch / PyG versions in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Names of the files that process() writes."""
        return ['data.pt']

    def process(self):
        """Collate ``self.data_list`` and write it to the processed file."""
        torch.save(self.collate(self.data_list), self.processed_paths[0])
path = "/mydatasets" # directory where the dataset files are saved
# Graphs to store; every element must be a torch_geometric.data.Data object.
# Extend this list with further Data objects as needed.
data_list = [data_1, data_2]
dataset = MyDataset(root=path, data_list=data_list) # the data file is saved automatically during initialisation
print(dataset[0], len(dataset))
(2) 读取 Dataset
from torch_geometric.data import InMemoryDataset


class MyDataset(InMemoryDataset):
    """In-memory dataset that loads graphs from an existing processed file.

    Args:
        root: Directory passed to the base class; the processed file
            ``data.pt`` is resolved through ``self.processed_paths``.
    """

    def __init__(self, root):
        super().__init__(root)
        # No process() is defined here, so the processed file presumably has
        # to exist already (written by the saving variant) — TODO confirm.
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject pickled graph objects —
        # confirm with the torch / PyG versions in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Names of the processed files this dataset expects."""
        return ['data.pt']
path = "/mydatasets" # directory where the dataset files were saved
dataset = MyDataset(root=path) # the data file is loaded automatically during initialisation
# Print the first graph and the number of graphs in the dataset.
print(dataset[0], len(dataset))
3. torch_geometric.utils 模块
3.1 edge index 和 adjacency matrix
to_dense_adj(edge_index, **args): edge_index -> adjacency matrix
dense_to_sparse(adj, **args): adjacency matrix -> edge_index
实例代码:to_dense_adj() 函数和 dense_to_sparse() 函数
# Round-trip a graph between its edge index and a dense adjacency matrix.
import torch_geometric.utils as tgu
data = dataset[0]
# edge_index -> adjacency matrix
adj = tgu.to_dense_adj(edge_index = data.edge_index)
print('size of adj:', adj.size())
# adjacency matrix -> edge_index
edge_index = tgu.dense_to_sparse(adj)
print('type of edge_index:', type(edge_index))
# the edge index itself is the first element of the returned value
edge_index = edge_index[0]
print('size of edge_index[0]:', edge_index.size())
3.2 Data 和 networkx.Graph
to_networkx(data, **args): Data -> networkx.Graph or networkx.DiGraph
from_networkx(G): networkx.Graph or networkx.DiGraph -> Data
实例代码:
# Convert between Data objects and networkx graphs; `tgu` and `data` come
# from the earlier snippets.
import networkx as nx
# Data -> networkx graph, then draw it
G = tgu.to_networkx(data)
pos = nx.spring_layout(G, seed=648) # Seed layout for reproducible node positions
nx.draw(G, pos, with_labels=True)
# build a small weighted networkx graph and convert it back to Data
G = nx.Graph()
G.add_weighted_edges_from([(0, 1, 0.6), (1, 2, 0.2), (1, 3, 0.1)])
data = tgu.from_networkx(G)
print(data)

浙公网安备 33010602011771号