GNN
这篇文章使用经典的 Cora 引文网络数据集来训练一个简单的 GCN(Graph Convolutional Network),完成一个半监督节点分类任务。
安装依赖:
pip install torch -i https://download.pytorch.org/whl/cu128
pip install torch_geometric
训练代码:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
# 加载 Cora 数据集,并对特征进行归一化处理
dataset = Planetoid(root="data/Planetoid", name="Cora", transform=NormalizeFeatures())
data = dataset[0]
# 定义一个简单的两层 GCN 模型
class GCN(th.nn.Module):
def __init__(self, hidden_channels):
super().__init__()
self.conv1 = GCNConv(dataset.num_features, hidden_channels) # 第一层:将输入特征映射到 hidden_channels
self.relu = nn.ReLU()
self.conv2 = GCNConv(hidden_channels, dataset.num_classes) # 第二层:将隐藏层再映射到类别数量
def forward(self, x, edge_index):
x = self.relu(self.conv1(x, edge_index)) # 第一层卷积 + ReLU
x = self.conv2(x, edge_index) # 第二层卷积(输出层)
return x
# 创建模型
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = GCN(hidden_channels=16).to(device)
data = data.to(device)
optimizer = Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 训练
for epoch in range(1, 201):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch: {epoch}, Loss: {loss:.4f}")
# 测试
model.eval()
logits = model(data.x, data.edge_index)
preds = logits.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
acc = (preds[mask] == data.y[mask]).sum() / mask.sum()
accs.append(acc.item())
print(f"Train Acc: {accs[0]:.4f}, Val Acc: {accs[1]:.4f}, Test Acc: {accs[2]:.4f}")
如果使用公式描述,单层 GCN 的更新操作可以表示为:
\[H^{(l+1)} = \sigma \Bigl(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)} \Bigr)
\]
其中,\(\tilde{A} = A + I\)(在邻接矩阵上加单位矩阵)表示加入自连接操作,\(\tilde{D}\) 是 \(\tilde{A}\) 的度矩阵,\(\sigma\) 表示激活函数(如 ReLU)。通过多层堆叠,就能让图卷积网络在邻居节点之间传递与聚合信息,从而完成节点分类或其他图任务。

浙公网安备 33010602011771号