GNN

这篇文章使用经典的 Cora 引文网络数据集来训练一个简单的 GCN(Graph Convolutional Network),完成一个半监督节点分类任务。

安装依赖:

pip install torch -i https://download.pytorch.org/whl/cu128
pip install torch_geometric

训练代码:

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures


# 加载 Cora 数据集,并对特征进行归一化处理
dataset = Planetoid(root="data/Planetoid", name="Cora", transform=NormalizeFeatures())
data = dataset[0]


# 定义一个简单的两层 GCN 模型
class GCN(th.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)  # 第一层:将输入特征映射到 hidden_channels
        self.relu = nn.ReLU()
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)   # 第二层:将隐藏层再映射到类别数量

    def forward(self, x, edge_index):
        x = self.relu(self.conv1(x, edge_index))  # 第一层卷积 + ReLU
        x = self.conv2(x, edge_index)             # 第二层卷积(输出层)
        return x


# 创建模型
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = GCN(hidden_channels=16).to(device)
data = data.to(device)
optimizer = Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 训练
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

# 测试
model.eval()
logits = model(data.x, data.edge_index)
preds = logits.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
    acc = (preds[mask] == data.y[mask]).sum() / mask.sum()
    accs.append(acc.item())

print(f"Train Acc: {accs[0]:.4f}, Val Acc: {accs[1]:.4f}, Test Acc: {accs[2]:.4f}")

如果使用公式描述,单层 GCN 的更新操作可以表示为:

\[H^{(l+1)} = \sigma \Bigl(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)} \Bigr) \]

其中,\(\tilde{A} = A + I\)(在邻接矩阵上加单位矩阵)表示加入自连接操作,\(\tilde{D}\)\(\tilde{A}\) 的度矩阵,\(\sigma\) 表示激活函数(如 ReLU)。通过多层堆叠,就能让图卷积网络在邻居节点之间传递与聚合信息,从而完成节点分类或其他图任务。

posted @ 2025-06-20 21:11  Undefined443  阅读(19)  评论(0)    收藏  举报