Paper Reading: High dimensional, tabular deep learning with an auxiliary knowledge graph


Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。

论文概况 详细
标题 《High dimensional, tabular deep learning with an auxiliary knowledge graph》
作者 Camilo Ruiz, Hongyu Ren, Kexin Huang, Jure Leskovec
发表会议 37th Conference on Neural Information Processing Systems (NeurIPS 2023)
发表年份 2023
会议等级 CCF-A
论文代码 https://github.com/snap-stanford/plato

作者单位:

  1. Department of Computer Science, Stanford University
  2. Department of Bioengineering, Stanford University

研究动机

在计算机视觉和自然语言处理等领域,机器学习模型在拥有大量标注数据时表现卓越。然而,在科学和生物医学等许多关键领域,获取的表格数据集常常具有极高的特征维度,但可用的标记样本数量却非常有限。这是因为实验通常耗时、昂贵且复杂,如药物筛选、基因测序等数据。当前主流的表格深度学习方法(如基于 MLP、Transformer、可微决策树的模型)主要针对样本数量远多于特征数量的场景,在特征数远多于样本数的高维稀疏场景下,这些深度模型拥有大量可训练参数,极易过拟合,导致性能不佳。
在特征维度远超样本数的设定下,主导方法仍是传统的统计机器学习方法,如降维、特征选择、对参数施加正则化惩罚的线性模型、决策森林。这些方法虽然能一定程度上应对过拟合,但可能无法充分捕捉复杂特征间的关系,且通常未能充分利用描述这些特征本身的丰富先验领域知识。因此本文旨在解决一个问题:如何为特征维度远大于样本数量的高维表格数据,设计一个有效的深度学习方法?更具体地,论文聚焦于满足以下两个条件的场景:

条件 说明
数据条件 存在一个表格数据集,其中特征数量 d 远大于样本数量 n。
知识条件 存在一个辅助的、异构的知识图谱,其中包含了描述这些输入特征的丰富领域信息。表格中的每个输入特征都对应于该知识图谱中的一个节点。KG 不仅包含这些特征节点,还包含代表更广泛领域概念的非特征节点以及多种关系类型。

文章贡献

在生物医学等领域的很多数据都是高维小样本表格数据\(d\gg n\)),同时这类数据往往伴随着丰富的、描述特征间关系的领域知识。本文提出了一种深度学习框架 PLATO,其核心在于利用辅助知识图谱来正则化 MLP 的第一层权重。该算法首先在描述输入特征的异质 KG 上,通过自监督学习为每个特征预训练一个通用嵌入。接着引入一个可训练的、基于注意力机制的消息传递函数,结合当前输入样本的值,在监督信号的驱动下更新这些特征嵌入,使其适应具体任务。最后,算法通过一个在所有特征间共享的小型神经网络,从更新后的特征嵌入中推断出 MLP 第一层的权重矩阵。这一过程利用了“KG 中相似的特征节点应具有相似的模型权重”的归纳偏置,并将需要学习的参数量从 \(O(dh)\) 减少到 \(O(ch)\)(其中\(c \ll d\)),缓解了该场景下的过拟合问题。在 6 个真实的 \(d\gg n\) 生物数据集上的实验表明,PLATO 的性能超过了包括传统统计方法、特征选择、图正则化、决策树集成以及现有表格深度学习模型在内的 13 种基线方法,消融实验验证了了其设计的必要性。

本文方法

PLATO 是一种用于处理高维、小样本(\(d\gg n\))表格数据,并利用辅助知识图谱(KG)的机器学习方法。其核心思想在于利用描述输入特征的丰富领域信息(以 KG 形式结构化)来正则化一个 MLP,从而在 \(d\gg n\) 场景下取得优异性能。
image

问题设定

给定一个表格数据集 \(X\in R^{n\times d}\) 和标签 \(y\in R^{n}\),满足特征数 \(d\) 远大于样本数 \(n\)\(d\gg n\)),目标是训练模型 \(\mathcal{F}\) 从输入 \(X\) 预测标签 \(\hat{y}\)。假设存在一个辅助知识图谱 \(G=(V, E)\),其中每个输入特征 \(j\) 都对应图 \(G\) 中的一个节点,即:\(\forall j\in\{1,\ldots, d\},\exists v\in V\)使得 \(j\mapsto v\)。知识图谱 \(G\) 不仅包含特征节点,还包含描述领域更广泛知识的其他节点,图中的边是(头节点,关系类型,尾节点)三元组。

PLATO 的归纳偏置

在 MLP 的第一层,每个输入特征 \(j\) 对应一个权重向量 \(\Theta_{j}^{[1]}\in R^{h}\),所有特征的权重向量构成第一层的权重矩阵 \(\Theta^{[1]}\in R^{d\times h}\)。PLATO 基于一个关键归纳偏置:在 KG 中相似的输入特征节点,其在 MLP 第一层中对应的权重向量也应该相似。PLATO 通过一个可训练的消息传递函数,从每个特征在 KG 中对应节点的信息来推断其权重向量,以此捕捉该偏置。PLATO 包含四个步骤:

序号 步骤 说明
1 预训练特征嵌入 在辅助 KG 上通过自监督学习为目标,为每个输入特征预训练一个嵌入向量
2 更新特征嵌入 使用一个在表格数据监督损失上训练的可训练消息传递函数,更新上一步得到的特征嵌入
3 推断 MLP 第一层权重 通过一个在特征间共享的小型神经网络,从更新后的特征嵌入推断 MLP 第一层的权重
4 预测 MLP 第一层权重来自推断,其余层权重正常训练,对输入样本进行预测

image

自监督预训练特征嵌入

该步骤从 KG \(G\) 中学习关于每个输入特征 \(j\) 的一般性先验信息,表示为低维嵌入 \(M_{j}\in R^{c}\)。形式化为 \(M=\mathcal{H}(G)\),其中\(M\in R^{d\times c}\)为所有特征嵌入的矩阵,\(\mathcal{H}\) 是自监督节点嵌入方法,如文中使用的 ComplEx。此步骤为预训练,仅使用 KG 不涉及表格数据 \(X, y\),预训练后的特征嵌入 \(M\) 固定。

表格消息传递更新特征嵌入

此步骤通过一个可训练的消息传递函数\(\mathcal{Q}\),基于表格数据的监督损失,更新预训练的特征嵌入\(M_j\)\(Q_j\)。形式化为 \(Q=\mathcal{Q}\left(M, G, X_{i};\Pi\right)\),函数输入包括:预训练特征嵌入 \(M\)、KG \(G\) 和当前样本值 \(X_i\),可训练权重 \(\Pi\) 在注意力机制中。初始化操作为 \(Q_{j}^{[0]}=M_{j}\),然后进行 \(R\) 轮消息传递,每轮更新公式为:

\[Q_{j}^{[r]}=\sigma\left[\beta\left(\sum_{k\in N_{j}}\alpha_{ijk} Q_{k}^{[r-1]}\right) + (1-\beta)Q_{j}^{[r-1]}\right] \]

该公式涉及的符号及其含义如下表所示:

符号 含义
\(\sigma\) 可选非线性激活函数
\(N_j\) 特征节点 \(j\)\(G\) 中的邻居。
\(\beta\) 超参数,控制来自邻居和来自节点自身信息的权重
\(\alpha_{ijk}\) 注意力系数,决定了对于样本 \(i\),邻居 \(k\) 对节点 \(j\) 的重要性

PLATO 的注意力系数计算是受 GAT 的启发,利用样本值 \(X_{ij}\)\(X_{ik}\) 计算节点 \(j\)\(k\) 之间的原始重要性分数 \(e_{ijk} = \mathcal{A}(X_{ij}, X_{ik};\Pi)\),其中 \(\mathcal{A}\) 是一个参数为\(\Pi\)的共享浅层神经网络。然后对邻居做 softmax 归一化得到\(\alpha_{ijk}\)

\[\alpha_{i j k}=\operatorname{softmax}_{k}\left(e_{i j k}\right)=\frac{\exp\left(e_{i j k}\right)}{\sum_{t\in N_{j}}\exp\left(e_{i j t}\right)} \]

经过 \(R\) 轮后,得到更新后的特征嵌入\(Q_j = Q_{j}^{[R]}\)

特征嵌入推断 \(\mathcal{F}\) 第一层权重

对于 MLP 第一层的每个输入特征\(j\),PLATO 从其更新后的特征嵌入 \(Q_j\) 推断对应的权重向量 \(\hat{\Theta}_{j}^{[1]}\)。该步骤形式化为 \(\hat{\Theta}_{j}^{[1]}=\mathcal{B}(Q_{j} | X_{i}; \Phi)\),其中 \(\mathcal{B}\) 是一个参数为 \(\Phi\) 的共享浅层神经网络。\(Q_j\) 的计算依赖于输入样本 \(X_i\),因此推断出的权重是样本依赖的。
PLATO 的关键优势在于大幅减少可训练参数量:标准 MLP 在第一层有 \(d \times h\) 个需训练的参数;PLATO 通过共享网络 \(\mathcal{B}\),只需训练 \(\Phi\) 中的参数。假设 \(\mathcal{B}\) 为单层网络,则 \(|\Phi| = c \times h\)。由于特征嵌入维度 \(c\) 远小于特征数 \(d\),因此 \(|\Phi| = ch \ll dh\),实现了参数量的减少。

实验结果

数据集和实验设置

在 6 个 \(d\gg n\) 的表格数据集上进行评估(MNSCLC, CM, PDAC, BRCA, CRC, CH),特征数(\(d\))在 12k-20k 之间,样本数(\(n\))在 286-924 之间,\(d/n\) 比值在 19.7 到 52.2 之间。扩展测试了 4 个 \(d\sim n\) 的数据集上进行测试(ME, BC, SCLC, NSCLC),\(d/n\) 比值在 1.1 到 2.0 之间。所有数据集均与生物医学相关,如预测癌细胞系或肿瘤模型对药物的反应。输入特征为基因表达值,每个样本还有一个额外的、固定的药物特征向量(药物在 KG 中的嵌入)。使用一个整合了多源生物医学知识的大型异构图,包含 108,447 个节点、3,066,156 条边和 99 种关系类型,所有数据集的基因和药物节点均包含在内。
与 13 种最先进的统计和深度学习方法进行对比,包括以下类别:

算法类型 对比模型
经典统计机器学习 Ridge
降维 PCA + 线性回归
特征选择 LASSO, 随机门 (STG)
决策树 XGBoost
图正则化 GraphNet, NC LASSO, Network LASSO (使用 KG 的特征子图)
参数推断 Diet Networks
表格深度学习 标准 MLP, NODE, TabTransformer, TabNet (FT-Transformer 因内存问题无法运行)

评估使用表格学习基准的评价标准,对每个模型在每个数据集上进行 500 次随机超参数搜索。数据按 60/20/20 划分训练、验证、测试集,报告在 3 个数据划分和每个划分上 3 次模型运行的均值±标准差,评估指标为预测值与真实值之间的皮尔逊相关系数。

对比实验

PLATO 在 \(d\gg n\) 场景下性能优于基线,结果可见 PLATO 在全部 6 个$ d\gg n$ 数据集上均取得了最佳性能。例如在 PDAC 数据集上,PLATO 相比最佳基线 XGBoost 的 PearsonR 提升了10.19%(0.400 vs. 0.363)。基线的表现高度依赖于具体数据集,没有一种方法能在所有数据集上保持稳定领先。例如,TabTransformer 在 MNSCLC 上表现第二,但在 PDAC 和 CH 上表现最差。这凸显了\(d\gg n\)问题的挑战性,也验证了 PLATO 利用 KG 先验知识的鲁棒性优势。
image
在 4 个 \(d\sim n\) 数据集上,PLATO 的表现与最强的基线 XGBoost 竞争激烈,互有胜负,但未表现出如在 \(d\gg n\) 场景下的显著优势。当样本量相对充足(\(d\sim n\))时,从数据本身可能已能训练出强预测模型(如 XGBoost),此时外部辅助信息(KG)的边际收益减小,印证了 PLATO 的核心价值在于解决数据稀缺(\(d\gg n\))的难题。

消融实验

对于可训练消息传递函数的重要性,在 BRCA 数据集上测试了 PLATO 的三个变体:

PLATO 消息传递变体 说明
完整 PLATO 使用可训练消息传递函数 \(\mathcal{Q}\) 更新后的特征嵌入 \(Q\) 来推断权重
使用预训练嵌入 直接使用预训练的特征嵌入\(M\)(未经消息传递更新)来推断权重
无 KG(标准 MLP) 不使用任何 KG 信息,退化为标准 MLP

结果可见完整 PLATO (0.583) > 使用预训练嵌入 (0.522) > 标准 MLP (0.240),这表明引入 KG 先验知识(\(M\))能显著提升性能,通过监督信号在具体任务上更新这些嵌入(\(Q\))是进一步提升性能的关键。
image

对于 KG 中非特征节点(领域知识)的重要性,在 BRCA 数据集上测试了 KG 的完备性:

PLATO 领域知识变体 说明
完整 KG 包含特征节点和其他领域知识节点
仅特征 KG 只包含特征节点及其之间关系的子图
无 KG 标准 MLP

结果可见完整 KG (0.583) > 仅特征 KG (0.539) > 无 KG (0.240),这表明 KG 中超出特征直接关系的、更广泛的领域信息(如基因功能、生物过程等)对于模型性能有重要贡献。
image

对于不完整 KG 的鲁棒性,实验随机移除 KG 中不同比例的边,测试 PLATO 在 BRCA 上的性能。结果可见即使只保留 50% 的边,PLATO 仍能保持完整性能的 71%(0.412 vs. 0.583),这表明 PLATO 采用的 KG 嵌入和消息传递机制对 KG 中的缺失信息具有一定的鲁棒性。
image

对于 MLP 深层(可训练权重层)的重要性,实验比较了完整 PLATO(MLP,第一层权重推断,其余层训练)与 PLATO-LR(线性回归,所有权重推断)。结果可见完整 PLATO (0.583) > PLATO-LR (0.550),这说明在推断的第一层权重之上,增加可训练的非线性变换层(MLP 第 2 到 L 层)对于最终性能是必要的,推断权重本身并不足以捕获所有复杂模式。
image

补充结果

在参数量对比方面,结果可见 PLATO 通过权重推断,可训练参数量相比标准 MLP 显著减少(通常少一个数量级),缓解了 \(d\gg n\) 下的过拟合风险。
image
使用不同的 KG 嵌入方法(TransE, DistMult, ComplEx)预训练特征嵌入 \(M\),PLATO 的性能相近。表明方法对具体的 KG 嵌入模块选择不敏感,具有通用性。
image
在模型排名稳定性方面,PLATO 在 \(d\gg n\) 数据集上排名始终第一,而各基线的排名波动很大,进一步验证了 PLATO 的稳健性。
image

优点和创新点

个人认为,本文有如下一些优点和创新点可供参考学习:

  1. 本文为高维小样本(\(d\gg n\))表格数据提出了一种可结合异质辅助 KG 进行深度学习的新问题框架,其中每个输入特征都对应于知识图谱中的一个节点。
  2. 提出了 PLATO 模型,其核心创新在于通过一个可训练的、样本依赖的消息传递函数从 KG 中推断出 MLP 第一层的全部权重,从而将 KG 节点相似性的领域知识转化为模型参数相似性的归纳偏置,实现了高效的正则化。
posted @ 2026-04-10 11:42  乌漆WhiteMoon  阅读(12)  评论(0)    收藏  举报