Loading

图预训练方法

STRATEGIES FOR PRE-TRAINING GRAPH NEURAL NETWORKS

原文链接:STRATEGIES FOR PRE-TRAINING GRAPH NEURAL NETWORKS

Abstract

本文主要提出了一个自监督学习的方法,用来做 gnn 的与训练,这个策略成功的关键是在单个节点和整个图的层次上预先训练一个到代表性的 GNN,以便 GNN 能够同时学习有用的局部和全局表示。作者发现传统的预训练策略,无论是在 graph 级别还是节点级别,都无法有效的提升下游任务的表现,甚至可能起到反效果。而作者的策略能有效的避免这个反效果,并且在分子属性预测和蛋白质功能预测上,提高了 9.4% 的 AUC。

Introduction

特定任务的标签是非常稀缺的。真实世界中,图数据通常包含与原始数据分布不一致的样本点,意思就是图数据中,训练集与测试集的数据分布很可能不一致。这是真实世界的图数据的一个普遍现象。以往的一些研究表明 (Xu et al., 2017; Ching et al., 2018; Wang et al., 2019). 一个成功的迁移学习不仅仅是增加与下游任务来自同一领域的标注好的预训练数据集的数量。相反,它需要大量的领域专业知识来仔细选择与感兴趣的下游任务相关的样本和目标标签。否则,可能带来反效果,被称之为 negative transfer .

本文工作的两个核心贡献:
  1. 作者首次系统的探索了大规模 GNN 预训练。为此,作者建立了两个新的预训练数据集,并且分享出来:一个有 2M graph 的化学数据集和一个有 395K graph 的生物数据集。同时表明,大尺度的特定领域数据集的预训练研究是至关重要的,现有的下游任务数据集太小,无法以统计上可靠的方式进行模型评估。
  2. 提出了一种有效的 GNN 预训练策略,并证明了该策略对于 hard transfer-learning 的有效性以及分布外的泛化能力。

值得注意的是,一个看似强大的预训练策略(即,对于图级预测任务,使用最先进的图神经网络架构的图级多任务监督预训练)只会带来边际性能增益。(marginal performance: 这个词的意思,综合可以理解成效果不好,增加的有限)。而且,还有可能带来反效果 (negative transfer)

本文的高效预训练策略,主要思想是用易于得到的节点级别的信息,去让 GNN 捕获节点与边以及图级别的特定领域的知识。这有助于 GNN 学习到全局和局部级别的有效表示,并且至关重要的是,能够迁移到不同的下游任务的图级(通过节点表示的集成)表示。

本文的策略与朴素的策略相比,最大的不同是朴素的策略要么只用在节点级别的表示,要么只捕获图级别的表示。

本文的预训练策略与表现最好的 GNN 结构 GIN 一起使用,在基准数据集上取得了 SOTA 的结果表现,并且避免了负迁移。另外,作者发现 GIN 在预训练的收益要比其他模型如 GCN,GraphSAGE 和 GAT 等等要高,并且预训练的 GNN 在微调阶段训练收敛速度高了一个量级

Methods

预训练策略的核心是同时在独立的节点级别与全图级别去预训练 GNN。下面接着介绍本文的节点级别预训练与图级别预训练,最后介绍整个预训练策略

针对节点级别的预训练 Node-Level Pre-Training

对于节点级别的预训练,本文的方法是用那些容易获得的没有标注的数据,来捕捉特定领域的知识与规则。下面提出了两个自学习策略,Context PredictionAttribute Masking.

Node-Level Pre Training

Context Prediction:Exploiting distribution of graph structure

在这部分,我们用子图去预测周围的图结构。这里的预训练的目标是能够使得相似结构的节点的 embedding 相似。这一部分学习的 embedding 方式类似于 graphsage 和 deep graph infomax 的负采样,简单来说就是定义了一个 context graph,用他的 embeding 和邻居做训练,认为两者相似,而不同节点的邻居和 context graph 不相似。

Neighborhood and context graphs.

For every node \(v\), we define \(v\) 's neighborhood and context graphs as follows. [首先定义针对节点\(V\)\(K\)阶邻居以及与之相关的概念\(\text{context graph}\) ]. \(K\)-hop neighborhood of \(v\) contains all nodes and edges that are at most \(K\)-hops away from \(v\) in the graph. This is motivated by the fact that a \(K\)-layer GNN aggregates information across the \(K\)-th order neighborhood of \(v\), and thus node embedding \(h_{v}^{(K)}\) depends on nodes that are at most \(K\)-hops away from \(v\).

We define context graph of node \(v\) as graph structure that surrounds \(v\) 's neighborhood. [这里定义了\(\text{context graph}\)的相关概念.]The context graph is described by two hyperparameters, \(r_{1}\) and \(r_{2}\), and it represents a subgraph that is between \(r_{1}\)-hops and \(r_{2}\)-hops away from \(v\) (i.e., it is a ring of width \(r_{2}-r_{1}\) ). Examples of neighborhood and context graphs are shown in Figure \(2(\mathrm{a}) .\) We require \(r_{1}<K\) so that some nodes are shared between the neighborhood and the context graph, and we refer to those nodes as context anchor nodes.[\(\text{context anchor nodes}\) 是指: 既在\(r_1\)-hop内, 又在\(r_2\)-hop内的node节点,这类节点共享了关于中心节点\(v\)的邻居信息,及\(v\)\(\text{context anchor nodes}\) 相连接的信息] These anchor nodes provide information about how the neighborhood and context graphs are connected with each other.

Encoding context into a fixed vector using an auxiliary GNN.

Directly predicting the context graph is intractable due to the combinatorial nature of graphs. This is different from natural language processing, where words come from a fixed and finite vocabulary. To enable context prediction, we encode context graphs as fixed-length vectors. To this end, we use an auxiliary GNN, which we refer to as the context \(GNN\). As depicted in Figure \(2(\mathrm{a})\).

How To Get Context Embedding ?

  1. Firstly, apply the context GNN to obtain node embeddings in the context graph.
  2. Secondly, average embeddings of context anchor nodes to obtain a fixed-length context embedding.
  3. For node \(v\) in graph \(G\), we denote its corresponding context embedding as \(c_{v}^{G}\).
Learning via negative sampling.

We then use negative sampling (Mikolov et al., 2013; Ying et al., 2018a) to jointly learn the main GNN and the context GNN.
The main GNN encodes neighborhoods to obtain node embeddings.
The context GNN encodes context graphs to obtain context embeddings.
In particular, the learning objective of Context Prediction is a binary classification of whether a particular neighborhood and a particular context graph belong to the same node:

\[\sigma\left(h_{v}^{(K) \top} c_{v^{\prime}}^{G^{\prime}}\right) \approx \mathbf{1}\{ v \text{ and } v^{\prime} \text{are the same nodes} \} \]

where \(\sigma(\cdot)\) is the sigmoid function, and \(\mathbf{1}(\cdot)\) is the indicator function. We either let \(v^{\prime}=v\) and \(G^{\prime}=G\) (i.e., a positive neighborhood-context pair), or we randomly sample \(v^{\prime}\) from a randomly chosen graph \(G^{\prime}\) (i.e., a negative neighborhood-context pair). We use a negative sampling ratio of 1 (one negative pair per one positive pair), and use the negative log likelihood as the loss function. After pre-training, the main GNN is retained as our pre-trained model.


# 用于编码 substruct 的 model

model_substruct = GNN(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio,  gnn_type=args.gnn_type).to(device) 

# 用于编码 context 的 model
model_context = GNN(int(l2 - l1), args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio,  gnn_type=args.gnn_type).to(device)

# substruct 中心节点  
print("center_substruct_idx : {}".format(batch["center_substruct_idx"]))  
  
# context_substruct 重叠节点序号  
print("overlap_context_substruct_idx : {}".format(batch["overlap_context_substruct_idx"]))  
  
# overlapped_context batch定位  
print("batch_overlapped_context : {}".format(batch["batch_overlapped_context"]))  
  
# 每个图在batch中的数目  
print("overlapped_context_size: {}".format(batch["overlapped_context_size"]))

"""
center_substruct_idx : tensor([ 7, 13, 32, 47, 63, 73, 84])
overlap_context_substruct_idx : tensor([ 3,  4,  5,  7, 12, 13, 11, 17, 18, 19, 20, 22, 23, 28, 33, 34, 31, 32, 37])
batch_overlapped_context : tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 5, 5, 5, 5, 6])
overlapped_context_size: tensor([2, 2, 3, 6, 1, 4, 1])
"""

# creating substructure representation
# 中心节点的embedding 表示  
# TODO: 每一个graph只有一个center_substruct_idx
substruct_rep = model_substruct(batch.x_substruct, batch.edge_index_substruct, batch.edge_attr_substruct)[batch.center_substruct_idx]  
  
# creating context representations  
# context anchor representation  
# TODO:一个graph中可能有多个context anchor  
overlapped_node_rep = model_context(batch.x_context, batch.edge_index_context, batch.edge_attr_context)[batch.overlap_context_substruct_idx]




# Contexts are represented by  
if args.mode == "cbow":  
 # positive context representation  
 context_rep = pool_func(overlapped_node_rep, batch.batch_overlapped_context, mode=args.context_pooling)  
 # negative contexts are obtained by shifting the indicies of context embeddings  
 neg_context_rep = torch.cat(  
 [context_rep[cycle_index(len(context_rep), i + 1)] for i in range(args.neg_samples)], dim=0)  
  
 pred_pos = torch.sum(substruct_rep * context_rep, dim=1)  
 pred_neg = torch.sum(substruct_rep.repeat((args.neg_samples, 1)) * neg_context_rep, dim=1)  
  
elif args.mode == "skipgram":  
  
 expanded_substruct_rep = torch.cat(  
 [substruct_rep[i].repeat((batch.overlapped_context_size[i], 1)) for i in range(len(substruct_rep))],  
        dim=0)  
 pred_pos = torch.sum(expanded_substruct_rep * overlapped_node_rep, dim=1)  
  
 # shift indices of substructures to create negative examples  
 shifted_expanded_substruct_rep = []  
 for i in range(args.neg_samples):  
 shifted_substruct_rep = substruct_rep[cycle_index(len(substruct_rep), i + 1)]  
 shifted_expanded_substruct_rep.append(torch.cat(  
 [shifted_substruct_rep[i].repeat((batch.overlapped_context_size[i], 1)) for i in  
 range(len(shifted_substruct_rep))], dim=0))  
  
 shifted_expanded_substruct_rep = torch.cat(shifted_expanded_substruct_rep, dim=0)  
 pred_neg = torch.sum(shifted_expanded_substruct_rep * overlapped_node_rep.repeat((args.neg_samples, 1)),  
                         dim=1)  
  
else:  
 raise ValueError("Invalid mode!")  
  
loss_pos = criterion(pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double())  
loss_neg = criterion(pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double())  
  
optimizer_substruct.zero_grad()  
optimizer_context.zero_grad()  
  
loss = loss_pos + args.neg_samples * loss_neg  
loss.backward()  
# To write: optimizer  
optimizer_substruct.step()  
optimizer_context.step()   

  1. 子图携带有一定的结构信息
  2. context涵盖了大图的一部分结构信息和边的信息
Attribute Masking: Exploiting distribution of graph structure

随机掩盖一些节点 / 边的特征,用特殊的标识代替,然后放进 gnn 学周围的 embedding,利用周围的 embedding 来预测这个特征。

In Attribute Masking, we aim to capture domain knowledge by learning the regularities of the node/edge attributes distributed over graph structure.

Masking node and edges attributes. Attribute Masking pre-training works as follows:

  1. We mask node/edge attributes and then we let GNNs predict those attributes (Devlin et al., 2019) based on neighboring structure. Figure 2 (b) illustrates our proposed method when applied to a molecular graph.
  2. Specifically, We randomly mask input node/edge attributes, for example atom types in molecular graphs, by replacing them with special masked indicators. We then apply GNNs to obtain the corresponding node/edge embeddings (edge embeddings can be obtained as a sum of node embeddings of the edge’s end nodes).
  3. Finally, a linear model is applied on top of embeddings to predict a masked node/edge attribute. Different from Devlin et al. (2019) that operates on sentences and applies message passing over the fully-connected graph of tokens, we operate on non-fully- connected graphs and aim to capture the regularities of node/edge attributes distributed over different graph structures. Furthermore, we allow masking edge attributes, going beyond masking node attributes.

Our node and edge attribute masking method is especially beneficial for richly-annotated graphs from scientific domains. For example, (1) in molecular graphs, the node attributes correspond to atom types, and capturing how they are distributed over the graphs enables GNNs to learn simple chemistry rules such as valency, as well as potentially more complex chemistry phenomenon such as the electronic or steric properties of functional groups. Similarly, (2) in protein-protein interaction (PPI) graphs, the edge attributes correspond to different kinds of interactions between a pair of proteins. Capturing how these attributes distribute across the PPI graphs enables GNNs to learn how different interactions relate and correlate with each other.

针对图级别的预训练 Graph-Level Pre-Training

本文目标是确保节点和图级别的 embedding 都是高质量的,所以可以鲁棒的迁移到下游任务。

Supervised Graph-Level Property Prediction

作者通过将在图级别预测任务去预训练 embedding 来得到特定的图级别任务领域知识。作者利用多任务联合学习不同的任务来预训练。
然而,朴素的直接多任务图级别的预训练可能在迁移的时候失效,因为这些任务可能与下游任务无关,造成 negative transfer。一个解决方案是选择真正相关的预训练任务,只拿这些任务预训练。然而,寻找这些任务代价非常大。
多任务学习的预训练只用来做图级别的预训练,因此,创造 graph embedding 的这些局部节点的 embedding 可能会变得无意义。这些无意义的节点会加剧 negative transfer,许多没有用的与训练任务会互相干扰。
所以,基于此,才有了之前先预训练节点级别的 embedding 的方法,作为正则,然后再进行 graph 级别的预训练。

Structual Simailarity Prediction

作者没做这个工作,只是觉得这是个不错的方向,写上来占坑,难点主要在于对于 graph 相似度的度量困难,目前没有一个比较好的定义。

[[GNN_Pretraining 代码解析]]

posted @ 2022-10-09 13:18  MarkL124950  阅读(319)  评论(0)    收藏  举报