笔记:GAT入门学习

GAT图注意力网络

GAT 采用了 Attention 机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于 inductive 任务。

假设 Graph 包含 $N$ 个节点,每个节点的特征向量为 $h_i$,维度是 $F$,如下所示:

\begin{gathered}
\boldsymbol{h}=\left\{h_{1}, h_{2}, \ldots, h_{N}\right\} \\
h_{1} \in R^{F}
\end{gathered}

节点 $j$ 是节点 $i$ 的邻居,则可以使用 Attention 机制计算节点 $j$ 对于节点 $i$ 的重要性,即 Attention Score:

\begin{gathered}
e_{i j}=\operatorname{Attention}\left(W h_{i}, W h_{j}\right) \\
\alpha_{i j}=\operatorname{Softmax}_{j}\left(e_{i j}\right)=\frac{\exp \left(e_{i j}\right)}{\sum_{k \in N_{i}} \exp \left(e_{i k}\right)}
\end{gathered}

注意这个 $w$ 都是同一个

GAT 具体的 Attention 做法如下,把节点 $i、j$ 的特征向量 $h'_i$、$h'_j$ 拼接在一起,然后和一个 $2F'$ 维的向量 $a$ 计算内积。激活函数采用 LeakyReLU,公式如下:

$$
\alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{j}\right]\right)\right)}{\sum_{k \in N_{i}} \exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{k}\right]\right)\right)}
$$
|| 表示拼接操作

经过 Attention 之后节点 $i$ 的特征向量如下:

$$h_{i}^{\prime}=\sigma\left(\sum_{j \in N_{i}} \alpha_{i j} W h_{j}\right)$$

GAT 也可以采用 Multi-Head Attention,如果有 K 个 Attention,则需要把 K 个 Attention 生成的向量拼接在一起,如下:

$$h_{i}^{\prime}=\operatorname{concat}\left(\sigma\left(\sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)\right)$$

但是如果是最后一层,则 K 个 Attention 的输出不进行拼接,而是求平均:

$$h_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)$$

网络结构:

样例来自 https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py

class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GAT, self).__init__()

        # num_features: Alias for num_node_features.
        self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)

        # On the Pubmed dataset, use heads=8 in conv2.
        self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False,
                             dropout=0.6)

    def forward(self, x, edge_index):
        ipdb.set_trace()
        x_copy = x.clone()
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x + x_copy  # Residual connection, 避免孤立节点变成全0
        
        # return F.log_softmax(x, dim=-1)  # log_softmax ??
        return x   #  我觉得这个位置还不要softmax

 

参考链接:https://ai.baidu.com/forum/topic/show/972764

posted @ 2021-11-12 20:59  Rogn  阅读(373)  评论(0编辑  收藏  举报