从零开始构建图注意力网络:GAT算法原理与数值实现详解
图数据在机器学习中的地位越来越重要。社交网络的用户关系、论文引用网络、分子结构,这些都不是传统的表格或序列数据能很好处理的。现实世界中实体之间的连接往往承载着关键信息。
图神经网络(GNN)的出现解决了这个问题,它让每个节点可以从邻居那里获取信息来更新自己的表示。图卷积网络(GCN)是其中的经典代表,但GCN有个明显的限制:所有邻居节点的贡献都是相等的(在归一化之后)。
这个假设在很多情况下并不合理。比如在社交网络中,不同朋友对你的影响程度肯定不一样;在分子中,也不是所有原子对化学性质的贡献都相同。
图注意力网络(GAT)就是为了解决这个问题而设计的。它引入注意力机制,让模型自己学会给不同邻居分配不同的权重,而不是简单地平均处理。用一个比喻来说,GCN像是"听取所有朋友的建议然后求平均",而GAT更像是"重点听那些真正懂行的朋友的话"。
本文文会详细拆解GAT的工作机制,用一个具体的4节点图例来演示整个计算过程。如果你读过原论文觉得数学公式比较抽象,这里的数值例子应该能让你看清楚GAT到底是怎么运作的。
GAT的核心思想
GAT的设计目标很直接:让每个节点能够智能地选择从哪些邻居那里获取信息,以及获取多少信息。
任何图都包含三个基本要素:节点(V)代表图中的实体,边(E)表示实体间的关系,特征(X)是每个节点的属性向量。
GAT层的工作流程可以概括为:输入节点特征,通过线性变换投影到新的特征空间,计算节点间的注意力分数,用softmax进行归一化,最后按注意力权重聚合邻居信息得到新的节点表示。
我们用一个简单的4节点图来演示这个过程。节点A、B、C、D的连接关系如下图所示:

为了便于手工计算,我们设定每个节点的特征维度为3:
https://avoid.overfit.cn/post/b1c7efd4b1004512a98ebf3fcecce8e7

浙公网安备 33010602011771号