论文信息
论文标题:Graph Rationalization with Environment-based Augmentations
论文作者:王宇杰、于奎、张玉宏、曹付源、梁吉业
论文来源:KDD 2022
发布时间:2022
论文地址:link
论文代码:link
1 研究动机&&研究问题
-
领域重要性:图神经网络(GNN)在化学信息学(分子属性预测)、材料信息学(聚合物属性预测)等领域应用广泛。例如,分子属性预测可辅助药物研发,聚合物属性预测(如氧气渗透性、玻璃化转变温度)能加速高性能材料(如气体分离膜、耐高温聚合物)的发现,解决工程与环境领域的关键挑战(如 biodegradability、高温稳定性需求)。
-
数据痛点:这类任务的核心瓶颈是数据集规模小—— 分子基准数据集通常仅 1000-10000 个图,聚合物数据集更小(如 O₂Perm 仅 595 个样本)。小数据导致 GNN 模型易过拟合、泛化能力差,难以稳定学习图结构与属性间的因果关系。
如何设计一种数据增强驱动的图合理化框架,在小样本场景下,高效、准确地识别图的关键子图,同时提升 GNN 模型的预测性能、泛化能力与可解释性?
Q1:如何利用 “关键子图 - 环境子图” 的分离特性,设计针对性的数据增强策略,为图合理化生成有效虚拟样本?
Q2:如何避免显式子图解码 / 编码的高复杂度,在 latent 空间中完成关键子图 - 环境子图的分离与表征学习?
Q3:该框架在真实分子 / 聚合物数据集上,是否能同时优于现有图池化、泛化优化、图合理化方法(如 DIR、OOD-GNN)?
Q4:框架识别的关键子图是否符合领域知识(如聚合物化学中功能基团与气体渗透性的关系),具备实际可解释性?
Q5:框架对超参数(如关键子图大小、聚合函数)是否敏感,能否保持稳定性能?
GREA(Graph Rationalization with Environment-based Augmentations)是一种基于环境增强的图合理化框架,核心目标是在小样本场景下,高效分离图的 “关键子图(Rationale)” 与 “环境子图(Environment)”,同时通过数据增强提升模型的预测准确性、泛化能力与可解释性。
-
避免显式子图解码 / 编码:在 latent 空间 中完成关键子图 - 环境子图的分离、增强样本生成与表征学习,降低计算复杂度;
-
双增强驱动训练:结合 “环境移除增强” 和 “环境替换增强”,利用环境子图的 “自然噪声” 特性,为关键子图识别提供多样化训练信号;
-
交替优化策略:分别训练 “分离器(fsep)” 和 “预测器(fpred)”,平衡关键子图分离精度与属性预测性能。
-
分离器( $GNN_1 + MLP_1$ ):对输入图进行节点级掩码预测,实现关键子图与环境子图的初步分离;
-
表征生成器( $GNN_2$ ):生成图的节点上下文表征,为子图表征计算提供基础;
-
增强样本生成:在 latent 空间中生成 “环境移除样本”(仅关键子图)和 “环境替换样本”(关键子图 + 其他图的环境子图);
-
预测器( $MLP_2$ ):基于两类增强样本联合训练,优化整体损失函数,输出最终预测结果。
通过可学习的掩码机制,在 latent 空间中精准分离输入图的关键子图(因果部分)与环境子图(非因果部分),无需显式构建子图结构。
$m = \sigma\left(MLP_1(GNN_1(g))\right)$
其中:
-
-
-
$GNN_1$ :编码器,生成节点的 latent 表征(捕捉节点特征与局部拓扑);
-
$MLP_1$ :解码器,将节点 latent 表征映射为一维概率值;
-
$\sigma$ :sigmoid 激活函数,确保掩码值 $m_v \in (0,1)$(即节点 $v$ 属于关键子图的概率);
- 环境掩码: $1_N - m$ ( $1_N$ 为 N 维全 1 列向量),表示节点属于环境子图的概率。
$h^{(r)} = 1_N^{\top} \cdot (m \times H)$
$h^{(e)} = 1_N^{\top} \cdot \left( (1_N - m) \times H \right)$
其中, $h^{(r)}, h^{(e)} \in \mathbb{R}^d$ (d 为表征维度),Sum Pooling 确保聚合后的表征保留子图的全局信息。
- 批次训练设定:设训练批次中包含 $B$ 个图 $g_1, g_2, ..., g_B$ ,通过 3.2 节方法已得到每个图的关键子图表征 $h_i^{(r)}$ 和环境子图表征 $h_i^{(e)}$ ( $i = 1,2,...,B$ );
关键子图作为图属性的因果核心,仅用其表征应能实现与原图接近的预测性能。通过 “移除环境子图”,仅保留关键子图用于训练,强化关键子图的预测能力。
给定图 $g_i$ 的关键子图表征 $h_i^{(r)}$ ,预测器输出:
$\hat{y}_i^{(r)} = MLP_2\left(h_i^{(r)}\right)$
其中 $MLP_2$ 为属性预测器的解码器,与 3.1 节图属性预测器的 MLP 结构一致。
环境子图是无关噪声,若将图 $g_i$ 的关键子图与其他图 $g_j$ ( $j \neq i$ )的环境子图组合,生成的虚拟样本应与 $g_i$ 具有相同标签(因关键子图未变)。通过这种替换,模型能学习到 “关键子图不变则标签不变” 的因果规律,忽略环境噪声干扰。
-
表征聚合:通过聚合函数 $AGG(\cdot, \cdot)$ 组合 $h_i^{(r)}$ ( $g_i$ 的关键子图表征)与 $h_j^{(e)}$ ( $g_j$ 的环境子图表征),得到虚拟样本的表征 $h_{(i,j)}$ ;
-
聚合函数可选:求和池化(Sum Pooling,默认)、平均池化(Mean Pooling)、最大池化(Max Pooling)、拼接(Concatenation)等,公式示例(Sum Pooling):
$h_{(i,j)} = AGG\left(h_i^{(r)}, h_j^{(e)}\right) = h_i^{(r)} + h_j^{(e)}$
-
预测计算:虚拟样本的预测标签应与 $g_i$ 的真实标签 $y_i$ 一致,预测公式:
$\hat{y}_{(i,j)} = MLP_2\left(h_{(i,j)}\right)$
-
样本数量:每个图 $g_i$ 可与批次中其他 $B-1$ 个图的环境子图组合,生成 $B-1$ 个虚拟样本,显著提升训练数据多样性。
-
针对性强:专为图合理化设计,直接关联关键子图与环境子图的分离逻辑;
-
无额外标注成本:虚拟样本的标签由原关键子图的标签继承,无需人工标注;
-
兼容性高:在 latent 空间中实现,不依赖图的具体结构,适用于分子、聚合物等各类图数据。
损失函数分为三类,分别对应增强样本预测、关键子图大小正则化,确保模型兼顾预测精度与关键子图合理性。
-
目标:优化 “环境移除样本” 的预测精度,确保关键子图具备独立预测能力;
-
公式(以二分类任务为例,采用交叉熵损失):
$\mathcal{L}_{rem} = y_i \cdot \log \hat{y}_i^{(r)} + (1 - y_i) \cdot \log (1 - \hat{y}_i^{(r)})$
-
回归任务适配:若为连续属性预测(如聚合物密度),则替换为均方误差(MSE)损失。
-
目标:优化 “环境替换样本” 的预测精度,确保模型忽略环境噪声,聚焦关键子图;
-
公式(以二分类任务为例):
$\mathcal{L}_{rep} = \frac{1}{B} \sum_{j=1}^{B} \left( y_i \cdot \log \hat{y}_{(i,j)} + (1 - y_i) \cdot \log (1 - \hat{y}_{(i,j)}) \right)$
其中 $\frac{1}{B}$ 为批次内虚拟样本的损失均值,平衡不同虚拟样本的贡献。
- 预测器损失(训练 $GNN_2 + MLP_2$ ):
$\mathcal{L}_{pred} = \mathcal{L}_{rem} + \alpha \cdot \mathcal{L}_{rep}$
其中,$\alpha$ 为超参数,控制环境替换损失的权重;
$\mathcal{L}_{sep} = \mathcal{L}_{rem} + \alpha \cdot \mathcal{L}_{rep} + \beta \cdot \mathcal{L}_{reg}$
其中,$\beta$ 为超参数,控制正则化损失的权重。
-
核心逻辑:分离器与预测器存在相互依赖(分离器的掩码质量影响预测器性能,预测器的损失反馈优化分离器),因此采用交替训练:
-
固定分离器 $f_{sep}$ ,训练预测器 $f_{pred}$ 共 $T_{pred}$ 个 epoch;
-
固定预测器 $f_{pred}$ ,训练分离器 $f_{sep}$ 共 $T_{sep}$ 个 epoch;
-
重复上述步骤,直至模型收敛。
-
超参数设置: $T_{sep} \in \{1,2\}$ , $T_{pred} \in \{2,3\}$ (通过验证集调优),确保训练稳定且高效。
-
推理时,仅需通过分离器 $f_{sep}$ 得到输入图的关键子图表征 $h^{(r)}$ ,代入预测器 $MLP_2$ 输出最终预测结果:
$\hat{y} = MLP_2\left(h^{(r)}\right)$
-
关键子图可视化:通过掩码 $m$ 筛选概率高于阈值(如 0.5)的节点,构建关键子图结构,用于可解释性分析。