CAN:借助先验分布提升分类性能的简单后处理技巧
顾名思义,本文将会介绍一种用于分类问题的后处理技巧——CAN(Classification with Alternating Normalization),出自论文《When in Doubt: Improving Classification Performance with Alternating Normalization》。经过笔者的实测,CAN确实多数情况下能提升多分类问题的效果,而且几乎没有增加预测成本,因为它仅仅是对预测结果的简单重新归一化操作。
有趣的是,其实CAN的思想是非常朴素的,朴素到每个人在生活中都应该用过同样的思想。然而,CAN的论文却没有很好地说清楚这个思想,只是纯粹形式化地介绍和实验这个方法。本文的分享中,将会尽量将算法思想介绍清楚。
思想例子 #
假设有一个二分类问题,模型对于输入aa给出的预测结果是p(a)=[0.05,0.95]p(a)=[0.05,0.95],那么我们就可以给出预测类别为11;接下来,对于输入bb,模型给出的预测结果是p(b)=[0.5,0.5]p(b)=[0.5,0.5],这时候处于最不确定的状态,我们也不知道输出哪个类别好。
但是,假如我告诉你:1、类别必然是0或1其中之一;2、两个类别的出现概率各为0.5。在这两点先验信息之下,由于前一个样本预测结果为1,那么基于朴素的均匀思想,我们是否更倾向于将后一个样本预测为0,以得到一个满足第二点先验的预测结果?
这样的例子还有很多,比如做10道选择题,前9道你都比较有信心,第10题完全不会只能瞎蒙,然后你一看发现前9题选A、B、C的都有就是没有一个选D的,那么第10题在蒙的时候你会不会更倾向于选D?
这些简单例子的背后,有着跟CAN同样的思想,它其实就是用先验分布来校正低置信度的预测结果,使得新的预测结果的分布更接近先验分布。
不确定性 #
准确来说,CAN是针对低置信度预测结果的后处理手段,所以我们首先要有一个衡量预测结果不确定性的指标。常见的度量是“熵”,对于p=[p1,p2,⋯,pm]p=[p1,p2,⋯,pm],定义为:
然而,虽然熵是一个常见选择,但其实它得出的结果并不总是符合我们的直观理解。比如对于p(a)=[0.5,0.25,0.25]p(a)=[0.5,0.25,0.25]和p(b)=[0.5,0.5,0]p(b)=[0.5,0.5,0],直接套用公式得到H(p(a))>H(p(b))H(p(a))>H(p(b)),但就我们的分类场景而言,显然我们会认为p(b)p(b)比p(a)p(a)更不确定,所以直接用熵还不够合理。
一个简单的修正是只用前top-kk个概率值来算熵,不失一般性,假设p1,p2,⋯,pkp1,p2,⋯,pk是概率最高的kk个值,那么
其中p~i=pi/∑i=1kpip~i=pi/∑i=1kpi。为了得到一个0~1范围内的结果,我们取Htop-k(p)/logkHtop-k(p)/logk为最终的不确定性指标。
算法步骤 #
现在假设我们有NN个样本需要预测类别,模型直接的预测结果是NN个概率分布p(1),p(2),⋯,p(N)p(1),p(2),⋯,p(N),假设测试样本和训练样本是同分布的,那么完美的预测结果应该有:
其中p~p~是类别的先验分布,我们可以直接从训练集估计。也就是说,全体预测结果应该跟先验分布是一致的,但受限于模型性能等原因,实际的预测结果可能明显偏离上式,这时候我们就可以人为修正这部分。
具体来说,我们选定一个阈值ττ,将指标小于ττ的预测结果视为高置信度的,而大于等于ττ的则是低置信度的,不失一般性,我们假设前nn个结果p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)属于高置信度的,而剩下的N−nN−n个属于低置信度的。我们认为高置信度部分是更加可靠的,所以它们不用修正,并且可以用它们来作为“标准参考系”来修正低置信度部分。
具体来说,对于∀j∈{n+1,n+2,⋯,N}∀j∈{n+1,n+2,⋯,N},我们将p(j)p(j)与高置信度的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)一起,执行一次“行间”标准化:
这里的k∈{1,2,⋯,n}∪{j}k∈{1,2,⋯,n}∪{j},其中乘除法都是element-wise的。不难发现,这个标准化的目的是使得所有新的p(k)p(k)的平均向量等于先验分布p~p~,也就是促使式(3)(3)的成立。然而,这样标准化之后,每个p(k)p(k)就未必满足归一化了,所以我们还要执行一次“行内”标准化:
但这样一来,式(3)(3)可能又不成立了。所以理论上我们可以交替迭代执行这两步,直到结果收敛(不过实验结果显示一般情况下一次的效果是最好的)。最后,我们只保留最新的p(j)p(j)作为原来第jj个样本的预测结果,其余的p(k)p(k)均弃之不用。
注意,这个过程需要我们遍历每个低置信度结果j∈{n+1,n+2,⋯,N}j∈{n+1,n+2,⋯,N}执行,也就是说是逐个样本进行修正,而不是一次性修正的,每个p(j)p(j)都借助原始的高置信度结果p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)组合来按照上述步骤迭代,虽然迭代过程中对应的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)都会随之更新,但那只是临时结果,最后都是弃之不用的,每次修正都是用原始的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)。
参考实现 #
这是笔者给出的参考实现代码:
# 预测结果,计算修正前准确率
y_pred = model.predict(
valid_generator.fortest(), steps=len(valid_generator), verbose=True
)
y_true = np.array([d[1] for d in valid_data])
acc_original = np.mean([y_pred.argmax(1) == y_true])
print('original acc: %s' % acc_original)
# 评价每个预测结果的不确定性
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
# 选择阈值,划分高、低置信度两部分
threshold = 0.9
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
# 显示两部分各自的准确率
# 一般而言,高置信度集准确率会远高于低置信度的
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean()
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean()
print('confident acc: %s' % acc_confident)
print('unconfident acc: %s' % acc_unconfident)
# 从训练集统计先验分布
prior = np.zeros(num_classes)
for d in train_data:
prior[d[1]] += 1.
prior /= prior.sum()
# 逐个修改低置信度样本,并重新评价准确率
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):
Y = np.concatenate([y_pred_confident, y[None]], axis=0)
for j in range(iters):
Y = Y**alpha
Y /= Y.mean(axis=0, keepdims=True)
Y *= prior[None]
Y /= Y.sum(axis=1, keepdims=True)
y = Y[-1]
if y.argmax() == y_true_unconfident[i]:
right += 1
# 输出修正后的准确率
acc_final = (acc_confident * len(y_pred_confident) + right) / len(y_pred)
print('new unconfident acc: %s' % (right / (i
