Graph PyTorch Geometric 库

PyTorch Geometric (PyG)

1. 安装

安装：参见 PyG 官方网站（site）的安装说明

# Install PyG and its companion packages inside a notebook (IPython) session.
# NOTE: the `!` lines are IPython shell escapes, not plain Python.
import os
import torch
print(torch.__version__)
# Export the installed torch version as the TORCH environment variable so the
# shell commands below select the matching wheel index via ${TORCH}.
os.environ['TORCH'] = torch.__version__

! pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
! pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
! pip install torch-geometric
# Alternative: install the development version directly from GitHub.
# ! pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# Optional: the ogb (Open Graph Benchmark) package.
! pip install ogb

2. torch_geometric.datasets 与 torch_geometric.data 模块

2.1 加载数据集

实例代码

from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import TUDataset


def _summarize(ds):
    """Print a dataset object, its number of graphs, and its first graph."""
    print(ds)
    print(len(ds))
    print(ds[0])


# Cora (Planetoid): data is stored under ./tmp/cora.
root, name = './tmp/cora', 'Cora'
dataset = Planetoid(root=root, name=name)
_summarize(dataset)

# ENZYMES (TUDataset): data is stored under ./tmp/ENZYMES.
root, name = './tmp/ENZYMES', 'ENZYMES'
dataset = TUDataset(root=root, name=name)
_summarize(dataset)
  • 每个 dataset 对象是一系列 graph 的集合

  • 可以通过索引 dataset[0] 获取一个 graph

2.2 torch_geometric.datasets 模块

Dataset 对象（torch_geometric.data.Dataset）常用属性

# Dataset-level summary, one "label value" line per statistic.
# NOTE(review): the label/attribute counters below are read from the ENZYMES
# TUDataset loaded above -- confirm they exist before reusing this snippet on
# other dataset classes.
# num_features and num_node_features are both printed as "No. of node features".
print("No. of graphs: ", len(dataset))
for _label, _attr in (
    # graph related
    ("No. of graph classes: ",   "num_classes"),
    # node related
    ("No. of node labels: ",     "num_node_labels"),
    ("No. of node features: ",   "num_features"),
    ("No. of node features: ",   "num_node_features"),
    ("No. of node attributes: ", "num_node_attributes"),
    # edge related
    ("No. of edge labels: ",     "num_edge_labels"),
    ("No. of edge features: ",   "num_edge_features"),
    ("No. of edge attributes: ", "num_edge_attributes"),
):
    # getattr is evaluated per iteration, so each value is read right before it is printed.
    print(_label, getattr(dataset, _attr))

2.3 torch_geometric.data 模块

2.3.1 torch_geometric.data.Data 类常用属性

# Inspect a single graph from the current dataset.
# NOTE(review): `dataset` here is the last one loaded (ENZYMES); the training
# mask and label-count lines are most meaningful for a node-classification
# graph such as Cora.
data = dataset[0]
# graph related
print(data)
print("Graph is directed:",       data.is_directed())
print("Graph is undirected:",     data.is_undirected())
print("Contains isolated nodes:", data.has_isolated_nodes())
print("Contains self-loops:",     data.has_self_loops())
# node related
print("No. of nodes: ",          data.num_nodes)
print("No. of node features: ",  data.num_features)
print("No. of node features: ",  data.num_node_features)
# edge related
print("No. of edges: ",          data.num_edges)
print("No. of edge features: ",  data.num_edge_features)

# Split masks exist only on some datasets (e.g. Planetoid/Cora). A graph
# without one would raise AttributeError here, so guard the lookup.
train_mask = getattr(data, 'train_mask', None)
if train_mask is not None:
    print("No. of training nodes: ", train_mask.sum().item())

# Number of distinct values in data.y for this graph.
print(data.y.unique().size()[0])

2.4 自定义数据集

2.4.1 自定义 data

参考资料:Data Handling of Graphs, PyG Tutorial

实例代码: 创建一个 Data 类型的数据

import torch
from torch_geometric.data import Data

# Graph 1: nodes 0, 1, 2 joined by two undirected edges, 0-1 and 1-2.
# Each undirected edge is listed in both directions (0->1, 1->0, 1->2, 2->1).
# Write the edges as an (E, 2) list and transpose to the (2, E) layout that
# edge_index uses.
edge_index = torch.tensor(
    [[0, 1],
     [1, 0],
     [1, 2],
     [2, 1]],
    dtype=torch.long,
).t().contiguous()
# Equivalent direct (2, E) form:
# edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

# Node features, shape (N, D) = (3, 1): one scalar feature per node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# Bundle node features and connectivity into a single graph object.
data_1 = Data(x=x, edge_index=edge_index)

实例代码: undirected graph 和 directed graph

# Graph 2 reuses the node features x but has only three directed edges:
# 0->1, 1->0 and 1->2. The reverse edge 2->1 is absent, so it is directed.
edge_index = torch.tensor([[0, 1, 1],
                           [1, 0, 2]], dtype=torch.long)
data_2 = Data(x=x, edge_index=edge_index)

# Expected output: Graph 1 -> False, Graph 2 -> True.
for _label, _graph in (("Graph 1 is directed: ", data_1),
                       ("Graph 2 is directed: ", data_2)):
    print(_label, _graph.is_directed())
  • edge_index: 为 edge 信息

    • Size 为 \((2, |\mathcal{E}|)\) 的 tensor，数据类型为 torch.long

    • 对于 undirected graph，同一个 edge 的正向和反向都应包括在内

2.4.2 自定义 datasets

(1) 保存 Dataset

import torch
from torch_geometric.data import InMemoryDataset


class MyDataset(InMemoryDataset):
    """In-memory dataset that collates a list of graphs and saves it to disk.

    Args:
        root: dataset root directory passed to ``InMemoryDataset``.
        data_list: list of ``torch_geometric.data.Data`` graphs to collate and
            save to ``self.processed_paths[0]``.
    """

    def __init__(self, root, data_list):
        # Set before super().__init__(), which may call process(); process()
        # reads self.data_list.
        self.data_list = data_list
        super().__init__(root)
        # NOTE(review): on recent torch versions torch.load defaults to
        # weights_only=True, which may reject pickled PyG objects -- confirm
        # for your torch/PyG versions (newer PyG offers self.load()).
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # collate() merges the graphs into one big Data object plus the
        # slices needed to split it back into individual graphs.
        # NOTE(review): PyG skips process() when the processed file already
        # exists, so a different data_list is ignored until that file is
        # removed -- verify for your PyG version.
        torch.save(self.collate(self.data_list), self.processed_paths[0])


# NOTE: an absolute path at the filesystem root may not be writable.
path = "/mydatasets"                 # directory the dataset files are saved under
# Graphs to save (torch_geometric.data.Data objects); append more as needed.
# A literal `...` must not appear in the list, because collate() cannot handle
# an Ellipsis entry.
data_list = [data_1, data_2]
dataset = MyDataset(root=path, data_list=data_list)  # saves the data file on construction
print(dataset[0], len(dataset))

(2) 读取 Dataset

import torch
from torch_geometric.data import InMemoryDataset


class MyDataset(InMemoryDataset):
    """Read-only variant: loads a previously saved, collated dataset from ``root``.

    No ``process()`` is defined, so this class only reads an existing
    ``data.pt`` file written by the saving variant.
    """

    def __init__(self, root):
        super().__init__(root)
        # NOTE(review): on recent torch versions torch.load defaults to
        # weights_only=True, which may reject pickled PyG objects -- confirm
        # for your torch/PyG versions (newer PyG offers self.load()).
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']


path = "/mydatasets"              # directory the dataset files were saved under
dataset = MyDataset(root=path)    # loads the data file on construction
print(dataset[0], len(dataset))

3. torch_geometric.utils 模块

3.1 edge index 和 adjacency matrix

to_dense_adj(edge_index, **args): edge_index -> adjacency matrix

dense_to_sparse(adj, **args): adjacency matrix -> edge_index

实例代码：to_dense_adj() 函数和 dense_to_sparse() 函数

import torch_geometric.utils as tgu

# edge_index -> dense adjacency. to_dense_adj returns a batched tensor of
# shape (batch, N, N); for a single graph that is (1, N, N).
# Without max_num_nodes, N is inferred from the largest index in edge_index,
# so trailing isolated nodes would be dropped. Pass the true node count.
data = dataset[0]
adj = tgu.to_dense_adj(edge_index=data.edge_index, max_num_nodes=data.num_nodes)
print('size of adj:', adj.size())

# dense adjacency -> edge_index. dense_to_sparse returns a tuple
# (edge_index, edge_attr), where edge_attr holds the non-zero adjacency values.
sparse = tgu.dense_to_sparse(adj)
print('type of edge_index:', type(sparse))
edge_index, edge_attr = sparse
print('size of edge_index[0]:', edge_index.size())

3.2 Data 与 networkx.Graph 的相互转换

to_networkx(data, **args): Data -> networkx.Graph or networkx.DiGraph

from_networkx(G): networkx.Graph or networkx.DiGraph -> Data

实例代码

import networkx as nx

# PyG Data -> networkx: convert the current graph and draw it.
G = tgu.to_networkx(data)
# Fixed seed so the node layout is reproducible between runs.
pos = nx.spring_layout(G, seed=648)
nx.draw(G, pos=pos, with_labels=True)

# networkx -> PyG Data: build a small undirected graph with weighted edges
# and convert it.
_weighted_edges = [(0, 1, 0.6), (1, 2, 0.2), (1, 3, 0.1)]
G = nx.Graph()
G.add_weighted_edges_from(_weighted_edges)
data = tgu.from_networkx(G)
print(data)
posted @ 2022-07-09 10:50  veager  阅读(545)  评论(0)    收藏  举报