Paper Reading: High dimensional, tabular deep learning with an auxiliary knowledge graph
Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。
| 论文概况 | 详细 |
|---|---|
| 标题 | 《High dimensional, tabular deep learning with an auxiliary knowledge graph》 |
| 作者 | Camilo Ruiz, Hongyu Ren, Kexin Huang, Jure Leskovec |
| 发表会议 | 37th Conference on Neural Information Processing Systems (NeurIPS 2023) |
| 发表年份 | 2023 |
| 会议等级 | CCF-A |
| 论文代码 | https://github.com/snap-stanford/plato |
作者单位:
- Department of Computer Science, Stanford University
- Department of Bioengineering, Stanford University
研究动机
在计算机视觉和自然语言处理等领域,机器学习模型在拥有大量标注数据时表现卓越。然而,在科学和生物医学等许多关键领域,获取的表格数据集常常具有极高的特征维度,但可用的标记样本数量却非常有限。这是因为实验通常耗时、昂贵且复杂,如药物筛选、基因测序等数据。当前主流的表格深度学习方法(如基于 MLP、Transformer、可微决策树的模型)主要针对样本数量远多于特征数量的场景,在特征数远多于样本数的高维稀疏场景下,这些深度模型拥有大量可训练参数,极易过拟合,导致性能不佳。
在特征维度远超样本数的设定下,主导方法仍是传统的统计机器学习方法,如降维、特征选择、对参数施加正则化惩罚的线性模型、决策森林。这些方法虽然能一定程度上应对过拟合,但可能无法充分捕捉复杂特征间的关系,且通常未能充分利用描述这些特征本身的丰富先验领域知识。因此本文旨在解决一个问题:如何为特征维度远大于样本数量的高维表格数据,设计一个有效的深度学习方法?更具体地,论文聚焦于满足以下两个条件的场景:
| 条件 | 说明 |
|---|---|
| 数据条件 | 存在一个表格数据集,其中特征数量 d 远大于样本数量 n。 |
| 知识条件 | 存在一个辅助的、异构的知识图谱,其中包含了描述这些输入特征的丰富领域信息。表格中的每个输入特征都对应于该知识图谱中的一个节点。KG 不仅包含这些特征节点,还包含代表更广泛领域概念的非特征节点以及多种关系类型。 |
文章贡献
在生物医学等领域的很多数据都是高维小样本表格数据(\(d\gg n\)),同时这类数据往往伴随着丰富的、描述特征间关系的领域知识。本文提出了一种深度学习框架 PLATO,其核心在于利用辅助知识图谱来正则化 MLP 的第一层权重。该算法首先在描述输入特征的异质 KG 上,通过自监督学习为每个特征预训练一个通用嵌入。接着引入一个可训练的、基于注意力机制的消息传递函数,结合当前输入样本的值,在监督信号的驱动下更新这些特征嵌入,使其适应具体任务。最后,算法通过一个在所有特征间共享的小型神经网络,从更新后的特征嵌入中推断出 MLP 第一层的权重矩阵。这一过程利用了“KG 中相似的特征节点应具有相似的模型权重”的归纳偏置,并将需要学习的参数量从 \(O(dh)\) 减少到 \(O(ch)\)(其中\(c \ll d\)),缓解了该场景下的过拟合问题。在 6 个真实的 \(d\gg n\) 生物数据集上的实验表明,PLATO 的性能超过了包括传统统计方法、特征选择、图正则化、决策树集成以及现有表格深度学习模型在内的 13 种基线方法,消融实验验证了了其设计的必要性。
本文方法
PLATO 是一种用于处理高维、小样本(\(d\gg n\))表格数据,并利用辅助知识图谱(KG)的机器学习方法。其核心思想在于利用描述输入特征的丰富领域信息(以 KG 形式结构化)来正则化一个 MLP,从而在 \(d\gg n\) 场景下取得优异性能。

问题设定
给定一个表格数据集 \(X\in R^{n\times d}\) 和标签 \(y\in R^{n}\),满足特征数 \(d\) 远大于样本数 \(n\)(\(d\gg n\)),目标是训练模型 \(\mathcal{F}\) 从输入 \(X\) 预测标签 \(\hat{y}\)。假设存在一个辅助知识图谱 \(G=(V, E)\),其中每个输入特征 \(j\) 都对应图 \(G\) 中的一个节点,即:\(\forall j\in\{1,\ldots, d\},\exists v\in V\)使得 \(j\mapsto v\)。知识图谱 \(G\) 不仅包含特征节点,还包含描述领域更广泛知识的其他节点,图中的边是(头节点,关系类型,尾节点)三元组。
PLATO 的归纳偏置
在 MLP 的第一层,每个输入特征 \(j\) 对应一个权重向量 \(\Theta_{j}^{[1]}\in R^{h}\),所有特征的权重向量构成第一层的权重矩阵 \(\Theta^{[1]}\in R^{d\times h}\)。PLATO 基于一个关键归纳偏置:在 KG 中相似的输入特征节点,其在 MLP 第一层中对应的权重向量也应该相似。PLATO 通过一个可训练的消息传递函数,从每个特征在 KG 中对应节点的信息来推断其权重向量,以此捕捉该偏置。PLATO 包含四个步骤:
| 序号 | 步骤 | 说明 |
|---|---|---|
| 1 | 预训练特征嵌入 | 在辅助 KG 上通过自监督学习为目标,为每个输入特征预训练一个嵌入向量 |
| 2 | 更新特征嵌入 | 使用一个在表格数据监督损失上训练的可训练消息传递函数,更新上一步得到的特征嵌入 |
| 3 | 推断 MLP 第一层权重 | 通过一个在特征间共享的小型神经网络,从更新后的特征嵌入推断 MLP 第一层的权重 |
| 4 | 预测 | MLP 第一层权重来自推断,其余层权重正常训练,对输入样本进行预测 |

自监督预训练特征嵌入
该步骤从 KG \(G\) 中学习关于每个输入特征 \(j\) 的一般性先验信息,表示为低维嵌入 \(M_{j}\in R^{c}\)。形式化为 \(M=\mathcal{H}(G)\),其中\(M\in R^{d\times c}\)为所有特征嵌入的矩阵,\(\mathcal{H}\) 是自监督节点嵌入方法,如文中使用的 ComplEx。此步骤为预训练,仅使用 KG 不涉及表格数据 \(X, y\),预训练后的特征嵌入 \(M\) 固定。
表格消息传递更新特征嵌入
此步骤通过一个可训练的消息传递函数\(\mathcal{Q}\),基于表格数据的监督损失,更新预训练的特征嵌入\(M_j\)为\(Q_j\)。形式化为 \(Q=\mathcal{Q}\left(M, G, X_{i};\Pi\right)\),函数输入包括:预训练特征嵌入 \(M\)、KG \(G\) 和当前样本值 \(X_i\),可训练权重 \(\Pi\) 在注意力机制中。初始化操作为 \(Q_{j}^{[0]}=M_{j}\),然后进行 \(R\) 轮消息传递,每轮更新公式为:
该公式涉及的符号及其含义如下表所示:
| 符号 | 含义 |
|---|---|
| \(\sigma\) | 可选非线性激活函数 |
| \(N_j\) | 特征节点 \(j\) 在 \(G\) 中的邻居。 |
| \(\beta\) | 超参数,控制来自邻居和来自节点自身信息的权重 |
| \(\alpha_{ijk}\) | 注意力系数,决定了对于样本 \(i\),邻居 \(k\) 对节点 \(j\) 的重要性 |
PLATO 的注意力系数计算是受 GAT 的启发,利用样本值 \(X_{ij}\) 和 \(X_{ik}\) 计算节点 \(j\) 和 \(k\) 之间的原始重要性分数 \(e_{ijk} = \mathcal{A}(X_{ij}, X_{ik};\Pi)\),其中 \(\mathcal{A}\) 是一个参数为\(\Pi\)的共享浅层神经网络。然后对邻居做 softmax 归一化得到\(\alpha_{ijk}\):
经过 \(R\) 轮后,得到更新后的特征嵌入\(Q_j = Q_{j}^{[R]}\)。
特征嵌入推断 \(\mathcal{F}\) 第一层权重
对于 MLP 第一层的每个输入特征\(j\),PLATO 从其更新后的特征嵌入 \(Q_j\) 推断对应的权重向量 \(\hat{\Theta}_{j}^{[1]}\)。该步骤形式化为 \(\hat{\Theta}_{j}^{[1]}=\mathcal{B}(Q_{j} | X_{i}; \Phi)\),其中 \(\mathcal{B}\) 是一个参数为 \(\Phi\) 的共享浅层神经网络。\(Q_j\) 的计算依赖于输入样本 \(X_i\),因此推断出的权重是样本依赖的。
PLATO 的关键优势在于大幅减少可训练参数量:标准 MLP 在第一层有 \(d \times h\) 个需训练的参数;PLATO 通过共享网络 \(\mathcal{B}\),只需训练 \(\Phi\) 中的参数。假设 \(\mathcal{B}\) 为单层网络,则 \(|\Phi| = c \times h\)。由于特征嵌入维度 \(c\) 远小于特征数 \(d\),因此 \(|\Phi| = ch \ll dh\),实现了参数量的减少。
实验结果
数据集和实验设置
在 6 个 \(d\gg n\) 的表格数据集上进行评估(MNSCLC, CM, PDAC, BRCA, CRC, CH),特征数(\(d\))在 12k-20k 之间,样本数(\(n\))在 286-924 之间,\(d/n\) 比值在 19.7 到 52.2 之间。扩展测试了 4 个 \(d\sim n\) 的数据集上进行测试(ME, BC, SCLC, NSCLC),\(d/n\) 比值在 1.1 到 2.0 之间。所有数据集均与生物医学相关,如预测癌细胞系或肿瘤模型对药物的反应。输入特征为基因表达值,每个样本还有一个额外的、固定的药物特征向量(药物在 KG 中的嵌入)。使用一个整合了多源生物医学知识的大型异构图,包含 108,447 个节点、3,066,156 条边和 99 种关系类型,所有数据集的基因和药物节点均包含在内。
与 13 种最先进的统计和深度学习方法进行对比,包括以下类别:
| 算法类型 | 对比模型 |
|---|---|
| 经典统计机器学习 | Ridge |
| 降维 | PCA + 线性回归 |
| 特征选择 | LASSO, 随机门 (STG) |
| 决策树 | XGBoost |
| 图正则化 | GraphNet, NC LASSO, Network LASSO (使用 KG 的特征子图) |
| 参数推断 | Diet Networks |
| 表格深度学习 | 标准 MLP, NODE, TabTransformer, TabNet (FT-Transformer 因内存问题无法运行) |
评估使用表格学习基准的评价标准,对每个模型在每个数据集上进行 500 次随机超参数搜索。数据按 60/20/20 划分训练、验证、测试集,报告在 3 个数据划分和每个划分上 3 次模型运行的均值±标准差,评估指标为预测值与真实值之间的皮尔逊相关系数。
对比实验
PLATO 在 \(d\gg n\) 场景下性能优于基线,结果可见 PLATO 在全部 6 个$ d\gg n$ 数据集上均取得了最佳性能。例如在 PDAC 数据集上,PLATO 相比最佳基线 XGBoost 的 PearsonR 提升了10.19%(0.400 vs. 0.363)。基线的表现高度依赖于具体数据集,没有一种方法能在所有数据集上保持稳定领先。例如,TabTransformer 在 MNSCLC 上表现第二,但在 PDAC 和 CH 上表现最差。这凸显了\(d\gg n\)问题的挑战性,也验证了 PLATO 利用 KG 先验知识的鲁棒性优势。

在 4 个 \(d\sim n\) 数据集上,PLATO 的表现与最强的基线 XGBoost 竞争激烈,互有胜负,但未表现出如在 \(d\gg n\) 场景下的显著优势。当样本量相对充足(\(d\sim n\))时,从数据本身可能已能训练出强预测模型(如 XGBoost),此时外部辅助信息(KG)的边际收益减小,印证了 PLATO 的核心价值在于解决数据稀缺(\(d\gg n\))的难题。
消融实验
对于可训练消息传递函数的重要性,在 BRCA 数据集上测试了 PLATO 的三个变体:
| PLATO 消息传递变体 | 说明 |
|---|---|
| 完整 PLATO | 使用可训练消息传递函数 \(\mathcal{Q}\) 更新后的特征嵌入 \(Q\) 来推断权重 |
| 使用预训练嵌入 | 直接使用预训练的特征嵌入\(M\)(未经消息传递更新)来推断权重 |
| 无 KG(标准 MLP) | 不使用任何 KG 信息,退化为标准 MLP |
结果可见完整 PLATO (0.583) > 使用预训练嵌入 (0.522) > 标准 MLP (0.240),这表明引入 KG 先验知识(\(M\))能显著提升性能,通过监督信号在具体任务上更新这些嵌入(\(Q\))是进一步提升性能的关键。

对于 KG 中非特征节点(领域知识)的重要性,在 BRCA 数据集上测试了 KG 的完备性:
| PLATO 领域知识变体 | 说明 |
|---|---|
| 完整 KG | 包含特征节点和其他领域知识节点 |
| 仅特征 KG | 只包含特征节点及其之间关系的子图 |
| 无 KG | 标准 MLP |
结果可见完整 KG (0.583) > 仅特征 KG (0.539) > 无 KG (0.240),这表明 KG 中超出特征直接关系的、更广泛的领域信息(如基因功能、生物过程等)对于模型性能有重要贡献。

对于不完整 KG 的鲁棒性,实验随机移除 KG 中不同比例的边,测试 PLATO 在 BRCA 上的性能。结果可见即使只保留 50% 的边,PLATO 仍能保持完整性能的 71%(0.412 vs. 0.583),这表明 PLATO 采用的 KG 嵌入和消息传递机制对 KG 中的缺失信息具有一定的鲁棒性。

对于 MLP 深层(可训练权重层)的重要性,实验比较了完整 PLATO(MLP,第一层权重推断,其余层训练)与 PLATO-LR(线性回归,所有权重推断)。结果可见完整 PLATO (0.583) > PLATO-LR (0.550),这说明在推断的第一层权重之上,增加可训练的非线性变换层(MLP 第 2 到 L 层)对于最终性能是必要的,推断权重本身并不足以捕获所有复杂模式。

补充结果
在参数量对比方面,结果可见 PLATO 通过权重推断,可训练参数量相比标准 MLP 显著减少(通常少一个数量级),缓解了 \(d\gg n\) 下的过拟合风险。

使用不同的 KG 嵌入方法(TransE, DistMult, ComplEx)预训练特征嵌入 \(M\),PLATO 的性能相近。表明方法对具体的 KG 嵌入模块选择不敏感,具有通用性。

在模型排名稳定性方面,PLATO 在 \(d\gg n\) 数据集上排名始终第一,而各基线的排名波动很大,进一步验证了 PLATO 的稳健性。

优点和创新点
个人认为,本文有如下一些优点和创新点可供参考学习:
- 本文为高维小样本(\(d\gg n\))表格数据提出了一种可结合异质辅助 KG 进行深度学习的新问题框架,其中每个输入特征都对应于知识图谱中的一个节点。
- 提出了 PLATO 模型,其核心创新在于通过一个可训练的、样本依赖的消息传递函数从 KG 中推断出 MLP 第一层的全部权重,从而将 KG 节点相似性的领域知识转化为模型参数相似性的归纳偏置,实现了高效的正则化。

浙公网安备 33010602011771号