GAT入门

GAT入门(Graph Attention Networks)

信息来源:https://zhuanlan.zhihu.com/p/81350196

1.聊点基础

① 图的两种“特征”

  • 对于任意顶点,其和邻居构成第一种特征,即图的结构关系
  • 每个顶点都有自己的特征,即GCN中提到的图信号,它可以是社交网络中每个用户个体属性,每个蛋白质的性质等
  • 进行图神经网络无外乎就是学习以上的两种特征

②GCN的局限

  • 无法处理动态图问题,在GCN中,图的固定邻居矩阵和图信号作为GCN层的输入,这两者一般是固定不变的,而动态图中频繁更新图邻居矩阵会带来巨大的计算和设计难度
  • 无法处理有向图,在上一节GCN中,我们默认是在无向图(邻居矩阵为实对称阵)的前提下进行讨论的
    • 在GCN中,聚合是统一的,没有区分入度和出度
    • 而有向图中入度和出度往往具有不同的含义,比如论文中的引用和被引用,社交网络中的的关注和被关注,而GCN将二者一视同仁,无法表示方向的差异 - (每个邻居被赋予相同的重要性,这不合理!)

③GAT的两种运算方式

  • Global graph attention:对每一个顶点都对图上的任意顶点进行attention计算,即下图中蓝色节点对全部节点都进行注意力计算
    • 这种方式不依赖图结构(因为对所有节点都进行一次计算),本质上是对所有点的全局建模,理论上有更强的表达能力
    • 但图本身是有结构的,如果不利于的图结构,相当于自废武功,且每轮对所有节点计算时,产生的计算开销极大
  • Mask graph attention:注意力机制的计算只在邻居节点上进行,即下图蓝色节点仅对橙色节点进行注意力计算

img

  • 对比Attention is all your need的Mask

    • Transformer的Mask控制注意力计算是否允许访问某些位置,这和GAT的Mask类似,GAT的Mask使当前节点只计算其邻居节点的注意力,也算是一种只允许访问某些位置
    • GAT的Mask就像给注意力加了“地图”,明确告诉当前节点,哪些边是连通的,哪些是不连通的
    • Transformer的Mask主要是控制当前时刻不能访问其他时刻的信息
  • 以下的内容均在Mask graph Attention的前提下讨论

2.GAT原理

  • 核心思想:邻居特征的重要性可以通过注意力机制自适应学习

  • 和Transformer一样,GAT的计算也分“计算注意力系数”和“加权求和”两步骤

①计算注意力系数

  • 首先对所有节点做一层共享(参数)的线性变换$h_i^{(1)} = Wh_i$

    • 这里的目的将原始节点特征映射到一个更适合做注意力的空间(增维)

    • 类似Transformer的QKV映射:对于输入X,分别乘上不同的权重矩阵得到QKV三个矩阵,这里扩展回顾一下Transformer的QKV

      • 一方面是将输入特征映射到注意力交互空间
      • 另一方面提供更多的灵活性和表达能力:如果直接在输入特征上做注意力,表达能力有限,加入QK映射后,可以为不同的关注方向建模不同的子空间,相当于告诉模型,你在提问(Q)和被提问(K)时的语义可以不同
    • GAT的注意力是所有节点先做一个统一的线性变换,再通过一个共享的注意力网络打分

  • 计算节点i和j之间的注意力系数

    • $e_{ij} = LeakyReLU(\vec{\alpha ^T}[Wh_i || Wh_j])$

      • 首先$[Wh_i || Wh_j]$将顶点i,j变换后的特征进行拼接
      • $\alpha$将拼接后的高维特征映射到一个实数上,这个过程和Transformer一样,使用单层前馈神经网络实现,之后加一层非线性
    • 之后归一化(转换为概率(0,1))懒得敲了直接上图吧

      image-20250526135601962

    • 最后得到的$\alpha_{ij}$即对于节点i而言,其邻居j的重要程度,一个节点进行注意力系数计算后,就可以知道它的所有邻居中,哪些比较重要,哪些比较不重要,这样的作用就是,每个节点可以根据邻居的重要程度进行动态加权邻居的特征,而不是像之前的平均或加权求和

②加权求和

  • 即根据计算好的注意力系数,把特征加和求权一下
  • $h_i' = \sigma(\sum_{j\in N_i}\alpha_{ij}Wh_j) $ ,输出一个节点特征$h_i$,经GAT后,输出节点新的特征$h'_i$
  • 和Transformer一样,引入多头注意力效果更好

3.代码实战

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

#加载数据集
dataset = Planetoid(root='D:/tmp/Cora', name='Cora')
data = dataset[0]

#定义GAT
class GAT(torch.nn.Module):
    def __init__(self,in_channels,hidden_channels,out_channels):
        super().__init__()
        self.gat1 = GATConv(in_channels,hidden_channels,heads = 8,dropout = 0.6) #8个注意力头并行计算
        self.gat2 = GATConv(hidden_channels * 8,out_channels,heads = 1,concat = False,dropout = 0.6) #concat = True:将各头的输出在特征维度上拼接,False为将各头的输出取平均

    def forward(self,data):
        x,edge_index = data.x,data.edge_index
        x = F.dropout(x,p = 0.6,training = self.training)
        x = F.elu(self.gat1(x,edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)

#初始化模型和优化器
model = GAT(dataset.num_node_features,8,dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(),lr = 0.005,weight_decay = 5e-4)

#训练过程
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

#测试过程
def test():
    model.eval()
    logits = model(data)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        pred = logits[mask].argmax(dim=1)
        acc = (pred == data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

# 跑训练
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 20 == 0:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
posted @ 2025-05-25 19:08  mumumu1  阅读(31)  评论(0)    收藏  举报