ICLR-2024《P²OT: Progressive Partial Optimal Transport for Deep Imbalanced Clustering》 - 详解
核心思想
这篇论文针对深度聚类(deep clustering)在现实世界中常见的不平衡数据分布问题,提出了一种新的问题设定——深度不平衡聚类(deep imbalanced clustering)。传统深度聚类方法多假设素材均匀分布,忽略了长尾分布的挑战,导致伪标签生成质量低下和确认偏差(confirmation bias)。核心思想是通过伪标签(pseudo-labeling)框架,引入渐进式部分最优传输(Progressive Partial Optimal Transport, P²OT)算法,将伪标签生成表述为一个单一的凸优化问题。该问题同时考虑类不平衡分布(通过KL散度约束)和样本置信度(通过总质量约束渐进选择高置信样本),从而生成不平衡感知的伪标签,并逐步从易样本过渡到难样本学习,避免手动阈值调参和退化解。
目标函数
给定模型对NNN个样本的预测P∈R+N×KP \in \mathbb{R}^{N \times K}_+P∈R+N×K(KKK为簇数)和伪标签Q∈R+N×KQ \in \mathbb{R}^{N \times K}_+Q∈R+N×K,初始交叉熵损失为:
L=−∑i=1NQilogPi=⟨Q,−logP⟩F, \mathcal{L} = -\sum_{i=1}^N Q_i \log P_i = \langle Q, -\log P \rangle_F,L=−i=1∑NQilogPi=⟨Q,−logP⟩F,
其中⟨⋅,⋅⟩F\langle \cdot, \cdot \rangle_F⟨⋅,⋅⟩F为Frobenius内积。为避免退化解(所有样本赋给单一簇),引入KL散度约束假设先验均匀分布1K1K\frac{1}{K}1_KK11K,并考虑不平衡分布,将伪标签生成表述为不平衡OT问题:
minQ∈Π⟨Q,−logP⟩F+λKL(Q⊤1N,1K1K), \min_{Q \in \Pi} \langle Q, -\log P \rangle_F + \lambda \mathrm{KL}(Q^\top 1_N, \frac{1}{K} 1_K),Q∈Πmin⟨Q,−logP⟩F+λKL(Q⊤1N,K11K),
s.t. Π={Q∈R+N×K∣Q1K=1N1N}\Pi = \{ Q \in \mathbb{R}^{N \times K}_+ \mid Q 1_K = \frac{1}{N} 1_N \}Π={Q∈R+N×K∣Q1K=N11N}。
为处理初始表示差导致的噪声伪标签,引入课程学习思想:依据总质量约束ρ\rhoρ(渐增)选择高置信样本,并允许样本权重调整,避免阈值调参。最终P²OT目标函数为:
minQ∈Π⟨Q,−logP⟩F+λKL(Q⊤1N,ρK1K), \min_{Q \in \Pi} \langle Q, -\log P \rangle_F + \lambda \mathrm{KL}(Q^\top 1_N, \frac{\rho}{K} 1_K),Q∈Πmin⟨Q,−logP⟩F+λKL(Q⊤1N,Kρ1K),
s.t. Π={Q∈R+N×K∣Q1K≤1N1N, 1N⊤Q1K=ρ}\Pi = \{ Q \in \mathbb{R}^{N \times K}_+ \mid Q 1_K \leq \frac{1}{N} 1_N, \, 1_N^\top Q 1_K = \rho \}Π={Q∈R+N×K∣Q1K≤N11N,1N⊤Q1K=ρ},
其中ρ∈[0,1]\rho \in [0,1]ρ∈[0,1]为选中质量分数,渐增策略为ρ=ρ0+(1−ρ0)⋅e−5(1−t/T)2\rho = \rho_0 + (1 - \rho_0) \cdot e^{-5(1 - t/T)^2}ρ=ρ0+(1−ρ0)⋅e−5(1−t/T)2(ρ0=0.1\rho_0=0.1ρ0=0.1,ttt为当前迭代,TTT为总迭代),KL为非归一化散度以处理不平衡。
目标函数详细的优化过程
P²OT是非标准不平衡OT,为高效求解,重构为标准形式:
引入虚拟簇吸收未选质量:扩展QQQ为Q^=[Q,ξ]∈R+N×(K+1)\hat{Q} = [Q, \xi] \in \mathbb{R}^{N \times (K+1)}_+Q^=[Q,ξ]∈R+N×(K+1),其中ξ∈RN×1\xi \in \mathbb{R}^{N \times 1}ξ∈RN×1为虚拟簇分配,满足Q^1K+1=1N1N\hat{Q} 1_{K+1} = \frac{1}{N} 1_NQ^1K+1=N11N和1N⊤ξ=1−ρ1_N^\top \xi = 1 - \rho1N⊤ξ=1−ρ。成本矩阵C=[−logP,0N]C = [-\log P, 0_N]C=[−logP,0N],目标变为:
minQ^∈Φ⟨Q^,C⟩F+λKL(Q^⊤1N,β), \min_{\hat{Q} \in \Phi} \langle \hat{Q}, C \rangle_F + \lambda \mathrm{KL}(\hat{Q}^\top 1_N, \beta),Q^∈Φmin⟨Q^,C⟩F+λKL(Q^⊤1N,β),
s.t. Φ={Q^∈R+N×(K+1)∣Q^1K+1=1N1N}\Phi = \{ \hat{Q} \in \mathbb{R}^{N \times (K+1)}_+ \mid \hat{Q} 1_{K+1} = \frac{1}{N} 1_N \}Φ={Q^∈R+N×(K+1)∣Q^1K+1=N11N},β=(ρK1K1−ρ)\beta = \begin{pmatrix} \frac{\rho}{K} 1_K \\ 1 - \rho \end{pmatrix}β=(Kρ1K1−ρ)。替换为加权KL确保严格约束:原KL无法严格保证1N⊤ξ=1−ρ1_N^\top \xi = 1 - \rho1N⊤ξ=1−ρ,引入加权KL:
KL^(Q^⊤1N,β,λ)=∑i=1K+1λi[Q^⊤1N]ilog[Q^⊤1N]iβi, \hat{\mathrm{KL}}(\hat{Q}^\top 1_N, \beta, \lambda) = \sum_{i=1}^{K+1} \lambda_i [\hat{Q}^\top 1_N]_i \log \frac{[\hat{Q}^\top 1_N]_i}{\beta_i},KL^(Q^⊤1N,β,λ)=i=1∑K+1λi[Q^⊤1N]ilogβi[Q^⊤1N]i,
并设λK+1→+∞\lambda_{K+1} \to +\inftyλK+1→+∞(实际用大值ι\iotaι),以强制虚拟簇大小为1−ρ1 - \rho1−ρ。目标为:
minQ^∈Φ⟨Q^,C⟩F+KL^(Q^⊤1N,β,λ). \min_{\hat{Q} \in \Phi} \langle \hat{Q}, C \rangle_F + \hat{\mathrm{KL}}(\hat{Q}^\top 1_N, \beta, \lambda).Q^∈Φmin⟨Q^,C⟩F+KL^(Q^⊤1N,β,λ).理论保证:命题1证明Q^⋆=[Q⋆,ξ⋆]\hat{Q}^\star = [Q^\star, \xi^\star]Q^⋆=[Q⋆,ξ⋆],其中Q⋆Q^\starQ⋆为原P²OT最优解。
熵正则化与缩放求解:添加−ϵH(Q^)-\epsilon H(\hat{Q})−ϵH(Q^)(ϵ=0.1\epsilon=0.1ϵ=0.1),问题可由Sinkhorn缩放算法求解。令M=exp(−C/ϵ)M = \exp(-C / \epsilon)M=exp(−C/ϵ),f=λλ+ϵf = \frac{\lambda}{\lambda + \epsilon}f=λ+ϵλ(逐元素),α=1N1N\alpha = \frac{1}{N} 1_Nα=N11N:
Q^⋆=diag(a)Mdiag(b), \hat{Q}^\star = \mathrm{diag}(a) M \mathrm{diag}(b),Q^⋆=diag(a)Mdiag(b),
其中a,ba, ba,b由迭代更新:
a←αMb,b←(βM⊤a)∘f, a \leftarrow \frac{\alpha}{M b}, \quad b \leftarrow \left( \frac{\beta}{M^\top a} \right) \circ f,a←Mbα,b←(M⊤aβ)∘f,
直到bbb收敛(变化<1e-6或迭代达1000)。提取Q=Q^[:,:K]Q = \hat{Q}[:, :K]Q=Q^[:,:K]。证明见附录C。
该过程将总质量约束融入边际,KL松弛为不平衡,优于通用不平衡OT求解器(快2倍)。
主要贡献点
- 问题与基准:首次形式化深度不平衡聚类,建立新基准(长尾CIFAR100、ImageNet-R、iNaturalist子集),填补现实差距。
- 框架与算法:提出渐进PL框架,用P²OT统一建模不平衡分布与置信选择,避免退化与偏差。
- 优化创新:重构P²OT为不平衡OT(理论保证),用高效缩放算法求解,支持mini-batch与记忆缓冲稳定。
- 实验验证:在基准上SOTA(如CIFAR100 ACC提升0.9%,ImageNet-R提升2.4%),证明鲁棒性。
算法实现过程
算法实现为Algorithm 1(缩放算法):
- 输入:成本−logP-\log P−logP,ϵ,λ,ρ,N,K\epsilon, \lambda, \rho, N, Kϵ,λ,ρ,N,K,大值ι\iotaι。
- 预处理:
- C←[−logP,0N]C \leftarrow [-\log P, 0_N]C←[−logP,0N](扩展成本)。
- λ←[λ,…,λ,ι]⊤\lambda \leftarrow [\lambda, \dots, \lambda, \iota]^\topλ←[λ,…,λ,ι]⊤(K+1维,加权)。
- β←[ρK1K⊤,1−ρ]⊤\beta \leftarrow [\frac{\rho}{K} 1_K^\top, 1 - \rho]^\topβ←[Kρ1K⊤,1−ρ]⊤。
- α←1N1N\alpha \leftarrow \frac{1}{N} 1_Nα←N11N。
- 初始化b←1K+1b \leftarrow 1_{K+1}b←1K+1,M←exp(−C/ϵ)M \leftarrow \exp(-C / \epsilon)M←exp(−C/ϵ),f←λλ+ϵf \leftarrow \frac{\lambda}{\lambda + \epsilon}f←λ+ϵλ。
- 迭代:
- while bbb未收敛:
- a←αMba \leftarrow \frac{\alpha}{M b}a←Mbα。
- b←(βM⊤a)∘fb \leftarrow \left( \frac{\beta}{M^\top a} \right) \circ fb←(M⊤aβ)∘f。
- while bbb未收敛:
- 输出:Q←diag(a)Mdiag(b)[:,:K]Q \leftarrow \mathrm{diag}(a) M \mathrm{diag}(b)[:, :K]Q←diag(a)Mdiag(b)[:,:K](前K列为伪标签)。
实际中,用mini-batch处理大数据集,记忆缓冲存储历史预测稳定优化。整体框架交替:用当前表示计算PPP→P²OT得QQQ→用QQQ更新表示→重复至收敛。超参:λ=1\lambda=1λ=1,ϵ=0.1\epsilon=0.1ϵ=0.1,ρ0=0.1\rho_0=0.1ρ0=0.1。
论文局限性总结与分析
该论文在深度不平衡聚类上创新性强,但存在以下局限:
- 依赖预训练模型:实验基于DINO预训练ViT-B16,未探讨从零训练场景。现实中,若无高质量预训练,初始表示差可能放大确认偏差,影响P²OT收敛。
- ρ\rhoρ策略固定:渐增用sigmoid函数,依赖迭代数t/Tt/Tt/T,非自适应。论文承认未来需基于学习进度动态调整,否则早期ρ\rhoρ过小可能欠拟合,晚期过快引入噪声。
- 计算效率与规模:虽缩放算法高效(矩阵-向量乘),但对超大规模资料集(如全iNaturalist),mini-batch与缓冲增加内存开销。未比较与投影梯度下降等方法的权衡。
- 不平衡假设:假设先验均匀,仅用KL松弛处理长尾,未显式建模极端不平衡(如R>>100)或多模态分布。噪声标签或分布外样本可能导致虚拟簇滥用。
- 评估局限:基准虽新,但ImageNet-R“分布外”挑战依赖预训练分布;未测试开放集聚类或动态簇数KKK。指标(如ACC)对尾类敏感,但ARI不适不平衡(仅附录)。
- 理论深度不足:虽有命题保证等价性,但缺乏全局收敛或泛化界;熵ϵ\epsilonϵ调参敏感,可能影响不平衡捕捉。
总体,该工作推进PL在不平衡场景的应用,但未来可探索自监督初始化、自适应约束和理论强化,以提升通用性。
浙公网安备 33010602011771号