快速了解 变分自编码器 VAE

概述

变分自编码器(Variational Auto-Encoders,VAE)是自编码器 AE 的变体,由 Kingma 等人于 2014 年提出的生成式网络结构。以概率的方式描述潜在空间,在数据生成方面潜力巨大。

自编码器 AE

自编码器(Auto-Encoder,AE),是一种无监督式学习模型。它可以将输入 \(X\) 映射为数据量小得多的潜在表示 \(h\),并能通过 \(h\) 尝试还原输入 \(X'\)

AE 包含两个部分:

  • Encoder 编码器,将输入 \(X\) 编码为潜在表示 \(h\)
  • Decoder 解码器,利用 \(h\) 重构输入 \(X'\)

AE 有着诸多好处。潜在表示 \(h\) 可以视为输入的重要特征,可以进行数据降维与压缩;在解码器重建数据时可以去除数据噪声,提高模型对噪声输入的鲁棒性;比起 CNN,训练不需要使用带标签的图像(无监督训练)。

不过要尤其注意过拟合问题。例如,AE 完全可以只用 \(h\) 中的一个数字 “死记硬背” 训练集中的每张图片 \(X\),这显然不是我们所期望的结果。

变分自编码器 VAE

VAE 对 AE 做了两个改动。

VAE 让编码器能够输出均值和方差,在推理阶段则从这样的正态分布里采样一个数据,作为解码器的输入。直观上看,这一改动就是在 AE 的基础上,让编码器多输出了一个方差,使原 AE 编码器的输出发生了一点随机扰动。

AE 的训练目标是,解码器输出尽可能与编码器输入相似。VAE 在此基础上增加了一项训练目标:让编码器输出尽可能贴近标准正态分布。

作为结果,VAE 的解码器被强迫从标准正态分布重建 \(X'\),有效解决了过拟合问题。并且由于标准正态分布可以不依靠编码器随机生成,VAE 还适合用于凭空生成新图像。

除了 VAE ,DDPM(Denoising Diffusion Probabilistic Model)也能处理过拟合问题,并且效果更好。

VAE 的损失函数

VAE 的 loss 函数包含两个部分:重构损失(Reconstruct Loss)和 KL 散度。

\[\text{loss}=\text{MSE}(X,X')+\text{KL}(N(\mu,\sigma^2),N(0,1)) \]

Reconstruct Loss 是解码器输出 \(X'\)编码器输入 \(X\) 之间的 MSE 损失,反映了 VAE 网络生成结果与输入数据的差异。

KL 散度意图获知编码器输出的变量分布与标准正态分布的差距,网络训练时期望这个差距越来越小。

KL 散度项的推导

KL 散度(Kullback-Leibler Divergence)是用来度量两个概率分布相似度的指标。

针对离散的随机变量 \(x\),假设有两种概率分布 \(P\)\(Q\),则 \(P\)\(Q\) 的 KL 散度为:

\[D_{KL}(P||Q)=\sum_{i}p(x_i)\ln \frac{p(x_i)}{q(x_i)} \]

可见,若两种分布完全一致,KL 散度达到最小值 0。\(P\)\(Q\) 差距越大,KL 散度也就越大。

针对 VAE 损失函数中的 \(\text{KL}(N(\mu,\sigma^2),N(0,1))\) 项,\(N(\mu,\sigma^2)\)\(N(0,1)\) 的概率密度函数分别为

\[p(x)=\frac{1}{\sqrt{2\pi \sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \]

\[q(x)=\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}} \]

带入到 KL 散度计算公式,可化简得到

\[\text{KL}(N(\mu,\sigma^2),N(0,1))=\frac{1}{2}(\mu^2+\sigma^2-2\ln (\sigma)-1) \]

VAE 的实现细节

重参数化技巧

VAE 的 encoder 从正态分布中采样数据,这个过程是不可微的。这导致梯度会在此不可传递,网络无法训练。

重参数化技巧(reparameterization trick)使得我们可以从带可变参数 \(\theta\) 的分布 \(p_\theta(x)\) 中采样,保留梯度信息。

具体来说,我们不直接从 \(N(\mu,\sigma^2)\) 中采样,而是先从 \(N(0,1)\) 采样,再用 \(\mu\)\(\sigma\) 对采样结果进行线性变换。这不就相当于也是 \(N(\mu,\sigma^2)\) 的采样结果。

对数方差

VAE 的 encoder 输出一组均值和方差,以供采样。

方差必须为非负数,而网络的输出可正可负。将网络输出视为对数方差 \(\ln \sigma\) 会更方便。

避免后验塌缩

后验坍塌(Posterior Collapse)问题,可以说是 VAE 独有的烦恼。

简单来说,若 decoder 足够强,强到能从纯噪声中生成理想结果,encoder 就失效了。具体体现是损失函数中的 KL 散度项几乎为 0,整体 loss 降不下去。

相关各种解决方法可以参考 这个文章

参考来源

posted @ 2024-03-25 23:27  倒地  阅读(194)  评论(0编辑  收藏  举报