Paper Reading: TabNet: Attentive Interpretable Tabular Learning
Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。
| 论文概况 | 详细 |
|---|---|
| 标题 | 《TabNet: Attentive Interpretable Tabular Learning》 |
| 作者 | Sercan O. Arık, Tomas Pfister |
| 发表会议 | The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21) |
| 会议年份 | 2021 |
| 论文代码 | https://github.com/google-research/google-research/tree/master/tabnet |
作者单位:
Google Cloud AI Sunnyvale, CA
研究动机
深度学习在图像、文本和音频等领域取得了显著成功,这很大程度上得益于标准架构,这些架构能高效地将原始数据编码为有意义的表征。然而,现实世界 AI 应用中最常见的表格数据却尚未出现这样一个获得广泛成功的标准深度学习架构,目前绝大多数表格数据学习任务仍由基于集成决策树的变体所主导。基于决策树的算法在表格数据上表现出色,主要得益于以下几点:
- 表征效率:对于表格数据中常见的、近似超平面边界的决策流形,决策树具有很高的表征效率;
- 可解释性:决策树的基本形式本身易于解释,而其集成形式也有流行的事后解释方法(如 SHAP);
- 训练速度快:与复杂的深度学习模型相比,决策树通常训练速度更快。
先前提出的深度学习架构并不适合表格数据,它们通常严重过参数化,并且缺乏适合表格数据的归纳偏置,这常常导致它们无法为表格决策流形找到最优解。尽管面临挑战,但是深度学习的潜力远超当前主导方法。对于大规模数据集,深度学习模型有望带来显著的性能提升,这符合模型性能随数据量增加而提升的普遍规律。与树学习相比,深度学习架构具有多重不可替代的优势:
- 多模态数据高效编码:可以整合并高效编码图像等多种数据类型与表格数据;
- 减少特征工程:可以减轻目前基于树的方法中对特征工程的高度依赖;
- 适应流式数据:更易于从数据流中持续学习;
- 支持表征学习:端到端模型允许进行表征学习,从而支持许多有价值的应用场景。
文章贡献
本文提出了一种用于表格数据深度学习的架构 TabNet,该模型的核心创新在于模仿决策树的特征选择能力,通过一种序列注意力机制(sequential attention) 来实现实例级的软特征选择。在每一步决策中,TabNet 都会动态地、稀疏地选择最相关的特征子集进行推理,从而将模型的学习能力集中在最显著的特征上。这不仅提高了学习效率,减少了冗余参数,还自然地为模型提供了内在的可解释性。其编码器由多个决策步骤组成,每个步骤包含一个用于特征选择的注意力变换器(Attentive Transformer) 和一个用于特征处理的特征变换器(Feature Transformer)。此外,TabNet 首次为表格数据引入了掩码自监督学习框架,通过预测被掩码的特征来进行预训练,从而能够有效利用大量未标注数据来提升模型在下游任务中的性能。通过广泛的实验验证,证明了TabNet 在多个不同领域的分类和回归数据集上达到或超越了当前主流表格学习模型的性能,同时提供了局部和全局两个层面的可解释性。

预备知识
Sparsemax
Sparsemax 是一个激活函数,是传统 Softmax 函数的一种替代方案。与 Softmax 总是产生一个稠密的概率分布(所有输出都大于零)不同,Sparsemax 能够产生一个稀疏的概率分布,即对于某些输入 Sparsemax 的输出会精确地为 0。
可以将 Sparsemax 理解为一个在概率单纯形上的欧几里得投影,目标为给定一个分数向量 $ \mathbf{z} = [z_1, z_2, ..., z_K] $,我们希望将其转换为一个概率分布 \(\mathbf{p} = [p_1, p_2, ..., p_K]\),其中 \(\sum_{i=1}^K p_i = 1\) 且 \(p_i \ge 0\)。Softmax 的做法是通过指数函数进行非线性变换,这保证了所有输出都大于零,但永远不会等于零。Sparsemax 的做法是寻找一个概率分布 \(\mathbf{p}\),使得 \(\mathbf{p}\) 与原始分数 \(\mathbf{z}\) 的欧几里得距离最小,同时满足概率分布的约束条件。这个“投影”操作将一部分较低的分数直接“截断”为 0,而剩余的概率质量则均匀地分配给那些“被选中”的维度。Sparsemax 函数的定义如下,其中 \(\Delta^{K-1}\) 是 \(K-1\) 维的概率单纯形,即满足 \(\sum_{i=1}^K p_i = 1\) 且 \(p_i \ge 0\) 的所有点的集合。
Sparsemax 的输出可以按以下步骤计算:
- 排序:将输入向量 \(\mathbf{z}\) 按从大到小的顺序排序。令 \(z_{(1)} \ge z_{(2)} \ge \cdots \ge z_{(K)}\) 表示排序后的值。
- 寻找支持集:找到支持集的大小 \(\kappa(\mathbf{z})\),即输出中非零元素的最大个数。它被定义为满足下式的最大 \(k\):\[\kappa(\mathbf{z}) = \max \left\{ k \in [1, K] \ \middle|\ 1 + k z_{(k)} > \sum_{j \le k} z_{(j)} \right\} \]一个更常见的等价计算方法是找到阈值函数 \(\tau(\mathbf{z})\),即需要找到最大的 \(k\),使得 \(z_{(k)} > \tau(\mathbf{z})\),这个 \(\tau\) 就是最终的阈值。\[\tau(\mathbf{z}) = \frac{\left(\sum_{j=1}^k z_{(j)}\right) - 1}{k} \]
- 计算阈值:根据找到的 \(k\) 计算阈值:\[\tau(\mathbf{z}) = \frac{\left(\sum_{j=1}^{\kappa(\mathbf{z})} z_{(j)}\right) - 1}{\kappa(\mathbf{z})} \]
- 输出结果:Sparsemax 的最终输出是输入分数减去阈值 \(\tau\),然后进行裁剪(ReLU 操作):\[\text{Sparsemax}(\mathbf{z})_i = \max(z_i - \tau(\mathbf{z}), 0) \]
给出一个简单的例子,假设有一个输入向量 \(\mathbf{z} = [1.0, 0.8, 0.2, -0.5]\):
- 排序:\(z_{(1)} = 1.0, \ z_{(2)} = 0.8, \ z_{(3)} = 0.2, \ z_{(4)} = -0.5\)。
- 寻找 \(k\):
- 测试 \(k=1\): \(\tau = (1.0 - 1)/1 = 0\)。\(z_{(1)}=1.0 > 0\),成立。
- 测试 \(k=2\): \(\tau = ((1.0+0.8) - 1)/2 = 0.8/2 = 0.4\)。\(z_{(2)}=0.8 > 0.4\),成立。
- 测试 \(k=3\): \(\tau = ((1.0+0.8+0.2) - 1)/3 = (2.0 - 1)/3 \approx 0.333\)。\(z_{(3)}=0.2 < 0.333\),不成立。
- 所以 \(\kappa(\mathbf{z}) = 2\)。
- 计算阈值:\(\tau = 0.4\)(由上一步 \(k=2\) 时已算出)。
- 输出:对每个 \(z_i\) 减去阈值 0.4 并裁剪。
- \(p_1 = \max(1.0 - 0.4, 0) = 0.6\)
- \(p_2 = \max(0.8 - 0.4, 0) = 0.4\)
- \(p_3 = \max(0.2 - 0.4, 0) = 0\)
- \(p_4 = \max(-0.5 - 0.4, 0) = 0\)
最终,\(\text{Sparsemax}(\mathbf{z}) = [0.6, 0.4, 0.0, 0.0]\)。可以看到,两个较小的分数被精确地置为了零,结果是一个稀疏分布。
门控线性单元
GLU(Gated Linear Unit,门控线性单元) 函数的核心思想来源于门控机制,类似于 LSTM 或 GRU 中的门控单元。它的目的是通过一个“门”来控制信息的流动,让模型能够学会在网络的每一层中保留哪些信息以及舍弃哪些信息。GLU 的操作是给定一个输入张量 \(X\),通常张量的最后一个维度(即特征维度)是偶数。首先将其均匀地分割为两部分 $ A, B = \text{split}(X, \text{axis}=-1) $,它们的形状完全相同。然后,GLU 的计算如下:
其符号的含义为:
| 符号 | 含义 |
|---|---|
| $ A $ | “信息”部分 |
| $ B $ | “门”部分 |
| $ \sigma $ | Sigmoid 函数,它将 $ B $ 中的每个元素的值压缩到 (0, 1) 区间 |
| $ \otimes $ | 逐元素乘法 |
Sigmoid 函数产生的门控值 \(\sigma(B)\) 就像一个“水龙头”或“阀门”,模型通过训练来学习如何生成最合适的门控信号 $ B $,以优化最终任务。
- 当门控值接近 1 时,对应位置的信息 \(A\) 几乎被完全保留。
- 当门控值接近 0 时,对应位置的信息 \(A\) 几乎被完全屏蔽。
一个简单的计算示例如下,假设我们有一个特征维度为 4 的输入向量 \(X = [1, 2, 3, 4]\),首先分割将其均匀分割为信息部分 $ A = [1, 2] $ 和门控部分 $ B = [3, 4] $。接着计算门控信号,将 $ B $ 输入 Sigmoid 函数得 \(\sigma(B) \approx [0.9526, 0.9820]\)。然后进行逐元素相乘:$ \text{GLU}(X) = A \otimes \sigma(B) = [1 * 0.9526, 2 * 0.9820] \approx [0.9526, 1.9640] $,可以看到,原始信息 $ A = [1, 2] $ 被门控信号缩放为了 $ [0.9526, 1.9640] $。
原始的 GLU 被进行了多种改进,产生了不同的激活函数,它们的主要区别在于用于生成门控信号的激活函数不同。它们的通用公式如下,其中 $ g $ 是某种激活函数。
常见的变体包括:
| 变体名称 | 特点 |
|---|---|
| 原始 GLU | 基础形式 |
| ReGLU | 使用 ReLU 作为门控,计算简单,且在大型模型中表现优异(如 PaLM 论文中使用)。 |
| GEGLU | 目前最流行和效果最好的变体之一。GELU 是高斯误差线性单元,它是 ReLU 的平滑版本,被证明在 Transformer 模型中效果非常好。 |
| SwiGLU | Swish 是另一个平滑且表现优异的激活函数,SwiGLU 在多项基准测试中表现突出。 |
GLU 的优势与特点有:
- 缓解梯度消失:门控机制为梯度流动提供了一条“高速公路”(类似于残差连接),使得梯度可以更有效地反向传播,从而允许构建更深的网络。
- 自适应学习:模型可以自适应地学习为每个特征维度分配不同的权重(重要性),而不是像传统激活函数(如 ReLU)那样进行固定的非线性变换。
- 提升模型表现:在实践中,尤其是在自然语言处理领域的 Transformer 模型中(如作为前馈神经网络 FFN 的激活函数),GEGLU 和 SwiGLU 等变体通常比传统的 ReLU 或 GELU 激活函数表现更好。
- 参数效率:虽然 GLU 需要将特征维度翻倍来产生 A 和 B(因此输入维度会变大),但许多研究发现,为了达到相同的性能,使用 GLU 的模型往往可以比使用标准激活函数的模型更小。
本文方法
TabNet 的设计灵感来源于决策树,通过特定的结构设计,传统的深度神经网络(DNN)模块可以模拟决策树的输出流形,如下图所示。其中稀疏的实例级特征选择是实现超平面形式决策边界的关键,TabNet 的核心目标是在保留决策树优势(如特征选择能力)的同时,通过深度学习提升模型性能。

TabNet 编码器架构
TabNet 编码器采用多步骤序列处理结构,如下图所示。每个决策步骤逐步处理输入特征并聚合信息,其核心组件包括特征选择机制和特征处理模块。

注意力变换器
特征选择机制通过注意力变换器(Attentive Transformer) 接收前一步处理的信息 \(a[i-1]\) 生成一个稀疏掩码 \(M[i] \in \Re^{B\times D}\),实现对特征的软选择(\(M[i] \cdot f\))。其结构如下图所示,注意力变换器包括全连接层、批归一化层和 Sparsemax 归一化函数。

各个组件的作用如下表所示:
| 注意力变换器组件 | 作用 | 功能 |
|---|---|---|
| 全连接层 | 学习并转换特征表示 | 将输入的特征信息进行非线性变换,并将其映射到与原始输入特征维度相同的空间。可以将其理解为一个可学习的评分器,为每个特征生成一个未经过规范化的原始分数,表示该特征在当前决策步骤中的潜在重要性。 |
| 批归一化层 | 稳定内部激活值分布,加速并稳定训练 | 深度神经网络的各层输入数据的分布会随着训练而发生变化(内部协变量偏移),这会导致训练变得困难。批归一化层对全连接层输出的初步重要性得分进行归一化处理,使其均值为 0,方差为 1。通过稳定数值分布,允许使用更大的学习率,从而加快模型的训练速度。同时减轻了对参数初始化的敏感性,使训练过程更加平滑和稳定。 |
| Sparsemax 归一化函数 | 实现稀疏特征选择 | 将重要性得分转化为稀疏的、专注于少数关键特征的掩码。 |
该步骤的公式如下,其中 \(P[i]\) 是先前特征使用程度的先验尺度项,定义为 \(P[i] = \prod_{j=1}^{i}(\gamma - M[j])\),\(\gamma\) 为松弛参数。
Sparsemax 归一化确保了掩码的稀疏性。接着引入稀疏正则化损失 \(L_{sparse}\) 以控制特征选择的稀疏性,其公式如下所示:
特征变换器
特征处理模块使用特征变换器(Feature Transformer) 处理被选中的特征,其结构包含共享层(跨步骤参数复用)和决策步骤依赖层,每层由全连接层、批归一化(BN)和门控线性单元(GLU)非线性激活组成,如下图所示。

数据流为:输入 → 全连接层(线性变换) → 批归一化层(稳定分布) → GLU(非线性门控) → 与输入进行归一化残差相加(信息融合与稳定),各个组件的作用如下表所示:
| 特征变换器组件 | 作用 | 功能 |
|---|---|---|
| 全连接层 | 实现特征的线性组合与交互 | 通过一个权重矩阵将输入向量映射到新的特征空间,使得模型能够学习到输入特征之间复杂的线性关系。 |
| 批归一化层 | 稳定训练并加速收敛 | 对全连接层的输出进行标准化处理(使其均值为 0,方差为 1) |
| 门控线性单元 | 引入可控的非线性 | 提供一种高效且可控的非线性激活机制,模拟信息门控 |
| 归一化残差连接 | 避免网络退化,确保梯度有效传播 | 模块内部采用了残差连接,并将残差路径的输出乘以一个缩放因子。通过残差连接,确保了即使深层网络发生退化,底层的信息也能直接传递到后方,保证了模型的基准性能。缩放操作有助于确保网络中各层的方差不会发生剧烈变化,从而进一步稳定训练过程。 |
如果每个步骤都使用完全独立、不共享参数的特征变换器,会导致模型参数量急剧增加,容易过拟合,并且训练效率低下。反之,如果所有步骤都强制共享同一个变换器,模型又可能缺乏足够的灵活性来为每个步骤学习独特的特征表示。因此,特征变换器由两个共享层和两个决策步骤依赖层组成,其核心动机是:在参数效率 和表示灵活性之间取得最佳平衡。两部分的具体作用如下:
| 层次 | 作用 | 功能 |
|---|---|---|
| 共享层 | 学习通用的、与决策步骤无关的特征基础表示 | 这些层的参数在所有决策步骤之间是共享的。由于每个步骤处理的是相同的原始特征集,共享层可以学习如何对这些特征进行一种“通用”的、基础的非线性编码。这种参数复用提高了模型的参数效率,避免了不必要的重复学习,使模型更加紧凑,并有助于减少过拟合的风险。 |
| 决策步骤依赖层 | 学习特定于当前决策步骤的、专门化的特征表示 | 这些层的参数是每个决策步骤独有的。由于 TabNet 的每个步骤通过注意力机制选择了不同的特征子集,每个步骤的“任务焦点”是不同的。在共享层完成了基础特征提取之后,步骤依赖层可以根据当前步骤的特定任务,对特征进行进一步的、专门化的加工,实现针对当前步骤所关注的特征子集,学习最有效的深层表示。 |
所有决策步骤(Step 1, Step 2, ..., Step N)中的这 2 个共享层共享同一套参数,每个决策步骤的步骤依赖层 1 和步骤依赖层 2 都拥有自己独有的一套参数。这两部分以串联方式协同作用,数据流如下:输入特征 → 共享层1 → 共享层2 → 步骤依赖层1 → 步骤依赖层2 → 输出 \([d[i], a[i]]\)。模块输出的两部分含义为:当前步骤的决策贡献 \(d[i]\) 和传递给下一步的信息 \(a[i]\),即 \([d[i], a[i]] = f_i(M[i] \cdot f)\)。
TabNet 编码器数据流
TabNet 编码器的预测输出流程是一个从多步骤决策贡献聚合到线性映射的过程,核心在于将每个决策步骤的学习成果合并,最终通过一个简单的输出层得到预测值。TabNet 编码器的数据流如下:
- 使用注意力变换器生成掩码:输入来自前一个决策步骤的处理后信息 \(a[i-1]\),对于第一个步骤有特定的初始化方式。该信息通过一个可学习函数 \(h_i\),由全连接层和批归一化层实现。输出结果与先验尺度项 \(P[i-1]\) 相乘,\(P[i-1]\) 记录了每个特征在之前步骤中被使用的累积情况。这一机制鼓励模型在后续步骤中关注之前使用较少的特征,促进探索。经过调制后的结果通过 Sparsemax 归一化函数,生成一个稀疏的、实例特定的特征选择掩码 \(M[i]\) 输出。
- 应用掩码进行特征过滤:将上一步生成的掩码 \(M[i]\) 与原始的输入特征 \(f\) 进行逐元素相乘(Hadamard 积),得到过滤后的特征 \(M[i] · f\)。在此过程中,被掩码忽略的特征(对应
M[i]值为0)将不参与当前步骤的后续计算,从而确保模型容量集中于最显著的特征上。 - 使用特征变换器进行非线性变换:输入过滤后的特征 \(M[i] · f\) 给特征变换器 \(f_i\),通过共享层(所有决策步骤参数共享)学习通用的特征变换,再通过步骤依赖层(每个步骤参数独有)学习针对当前步骤的特定表示。特征变换器输出一个被深度编码的特征表示,并分割成两个部分:当前步骤对最终决策的贡献 \(d[i] ∈ ℜ^(B×N_d)\)、传递给下一步骤注意力变换器的信息 \(a[i] ∈ ℜ^(B×N_a)\)。
- 信息聚合与传递:当前步骤的决策贡献 \(d[i]\) 将被暂存,所有步骤的 \(d[i]\) 最终会通过聚合形成总体决策嵌入。信息 \(a[i]\) 被直接送入下一个决策步骤(第 \(i+1\) 步)的注意力变换器作为其输入,开始新一轮的特征选择与推理循环。
- 是更新先验尺度:在完成当前步骤后根据新生成的掩码 \(M[i]\) 更新先验尺度 \(P[i]\),为下一个步骤的特征选择做好准备。
最后 TabNet 通过“多步骤特征选择与推理 → ReLU 激活的决策贡献聚合 → 线性层映射 → Softmax/Argmax 输出”得到预测结果。TabNet 采用了一种线性求和的方式,并引入非线性激活函数来保证稳定性,将所有 \(N_{steps}\) 个步骤的决策贡献合并为一个总体表示:
得到一个融合了所有步骤信息的总体决策嵌入向量\(d_{out} \in \Re^{B \times N_d}\) 后,TabNet 使用一个线性映射层(即一个全连接层)来生成最终的预测结果:\(\text{Output} = W_{final} \cdot d_{out}\)。
可解释性设计
TabNet 通过以下机制提供局部和全局可解释性:
- 局部解释:每一步的掩码 \(M[i]\) 显示该步骤所选特征。
- 全局特征重要性:通过聚合各步骤的掩码权重,计算整体特征重要性 \(M_{agg}\):\[M_{agg-b,j} = \sum_{i=1}^{N_{steps}} \eta_b[i] M_{b,j}[i] / \sum_{j=1}^{D} \sum_{i=1}^{N_{steps}} \eta_b[i] M_{b,j}[i] \]其中 \(\eta_b[i] = \sum_{c=1}^{N_d} \text{ReLU}(d_{b,c}[i])\) 表示第 \(i\) 步对决策的贡献度。
TabNet 解码器与自监督学习
TabNet 的解码器主要用于实现自监督学习任务,其核心目标是从编码器生成的表示中重建出原始的表格特征。解码器将编码过程中分散在各个决策步骤的信息重新整合和上采样,以恢复原始输入。其输入是编码器的输出表示,即经过多个决策步骤处理并聚合后的信息。它的最终输出是重建的特征向量,其维度与原始输入特征相同。解码器由一系列特征变换器 和全连接层构成,结构如下图所示:

其工作流程可以概括为以下几个步骤:
- 逐步骤的特征变换:使用特征变换器块实现。解码器为每个决策步骤(从第 1 步到第 N_step 步)都配备了一个独立的特征变换器,这些变换器在结构上与编码器中使用的特征变换器类似。编码器每个步骤输出的中间表示会被分别送入解码器对应步骤的特征变换器中进行处理,实现对编码后的信息进行初步的逆变换和增强,为重建特征做准备。
- 使用全连接层进行特征维度映射:在每个决策步骤中,经过特征变换器处理后的数据会通过一个全连接层,主要功能是将高维的、抽象的编码表示映射回原始特征的维度(D 维),将深度特征空间转换回原始的表格特征空间。
- 步骤输出的聚合:解码器将所有决策步骤经过全连接层重建出的特征向量进行逐元素相加,得到一个最终的重建特征向量。这种求和操作基于一个假设:编码器的每个决策步骤都学习并编码了输入数据的不同方面。因此在重建时,需要将所有步骤所贡献的信息重新聚合起来,才能更完整地恢复原始输入。
TabNet 解码器的结构虽然比编码器简单,但其实现了从高层表示到原始特征空间的重建映射。它采用多步骤并行变换再聚合的方式,镜像了编码器的学习过程。

其具体流程如下:
- 输入掩码:一个二进制掩码 \(S \in \{0,1\}^{B\times D}\) 被应用于原始特征 \(f\),生成一个部分被掩盖的输入 \((1-S) \cdot \hat{f}\)。其中,值为 0 的位置表示该特征值被掩盖(未知),需要模型预测。
- 编码过程:编码器接收被掩盖的输入 \((1-S) \cdot \hat{f}\)。为了引导编码器只关注已知特征,先验尺度P[0]被初始化为 \((1-S)\)。这意味着被掩盖的特征在第一步就被标记为“已使用”,从而被模型忽略。
- 解码与重建:编码器产生的表示被送入解码器。解码器输出重建的所有特征。
- 损失计算:损失函数仅计算在被掩盖的特征(即S矩阵中值为1的位置)上的重建误差。文档中采用的损失是经过标准差归一化后的均方误差,这样做的好处是让不同尺度的特征对损失的贡献相对均衡。
实验结果
实例级特征选择的有效性
实验首先在 6 个合成数据集(Syn1-Syn6)上进行,这些数据集被设计为只有特征的一个子集决定输出结果,其中 Syn1-Syn3 是全局重要特征,Syn4-Syn6 是实例依赖的重要特征。实验结果如下表所示,TabNet 的性能优于或与其他特征选择方法(如 L2X, INVASE)相当。在全局重要特征数据集(Syn1-Syn3)上,TabNet 的性能接近 Global 方法。在实例依赖特征数据集(Syn4-Syn6)上,TabNet 通过消除实例级冗余特征,性能超过了“Global”方法。TabNet 的参数量(26k-31k)远少于 INVASE 等需要多个模型的方案(101k),体现了其参数效率。

真实数据集上的性能
实验在多个真实世界数据集上进行了测试,包括分类和回归任务。Forest Cover Type(森林覆盖类型分类)任务需要根据制图变量分类森林覆盖类型,结果可见 TabNet(96.99% 准确率)显著优于 XGBoost、LightGBM 等梯度提升树模型,甚至超过了经过自动化超参数搜索的 AutoML Tables 框架(94.95%)。

Poker Hand(扑克手牌分类)任务需要根据扑克牌的花色和等级分类手牌,这是一个具有确定性规则但数据高度不平衡的任务。实验结果可见传统 MLP、DT 及其混合模型表现不佳,梯度提升树略有提升但准确率仍低(约 71%)。TabNet 取得了接近规则方法的优异性能(99.2%),体现了其处理复杂非线性关系的能力。

机器人逆动力学回归的任务是回归拟人机器人手臂的逆动力学,实验结果为在模型大小受限时(TabNet-S,6.3K 参数),TabNet 与参数量大 100 倍的最佳模型性能相当。当不限制模型大小时(TabNet-L,1.75M 参数),TabNet 的测试 MSE(0.14)比现有最佳模型低一个数量级。

Higgs Boson(希格斯玻色子分类)的目的是区分希格斯玻色子信号与背景噪声,实验结果为在大规模数据集(1050 万实例)上 TabNet 的性能优于 MLP。且 TabNet 与先进的稀疏进化 MLP 性能相当,但 TabNet 的结构化稀疏更利于计算效率。

Rossmann Store Sales(零售销售额预测)需要根据静态和时序特征预测商店销售额,实验结果为 TabNet(MSE: 485.12)超越了所有对比的梯度提升树方法。

可解释性分析
实验通过可视化特征重要性掩码来展示 TabNet 的可解释性。合成数据的可视化效果下图所示,在 Syn2 上 TabNet 准确地将注意力集中在真正相关的特征(X3-X6)上,无关特征的重要性几乎为零。在 Syn4 上 TabNet 能根据指示特征(X11)动态选择不同的特征组(X1-X2 或 X3-X6)。

对于真实数据,在成人人口普查收入预测中,TabNet 给出的特征重要性排名(如“Age”最重要)与领域共识一致,并通过 t-SNE 可视化展示了“Age”特征对决策空间的清晰划分。

自监督学习
TabNet 采用掩码特征预测任务进行无监督预训练,然后用有标签数据对模型进行微调。在 Higgs 数据集上,随着有标签数据量的减少,预训练带来的提升越明显。当有标签数据为 1k 时,预训练将准确率从 57.47% 提升至 61.37%。即使有标签数据增至 100k,仍能观察到性能提升。

如下图所示,自监督预训练不仅提升了最终性能,还大幅加快了模型收敛速度,这对于持续学习和领域自适应非常有益。

优点和创新点
个人认为,本文有如下一些优点和创新点可供参考学习:
- TabNet 创新性地采用了序列注意力机制,实现了实例级的软特征选择,使模型在每一步决策中都能动态、稀疏地聚焦于当前最相关的特征子集;
- 该模型成功统一了高性能与内在可解释性,通过可视化每一步的特征选择掩码,既能提供局部实例的解释,也能聚合得到全局特征重要性;
- 本文将掩码自监督学习框架引入表格数据,通过预测被掩码的特征进行预训练,能够有效利用无标签数据来提升下游监督任务的性能。

浙公网安备 33010602011771号