机器学习课程笔记5

上一篇 -> 机器学习课程笔记 4。

生成式对抗网络(Genrative Adversarial Network)

image

  • 输入:X 和简单分布下的随机值Z(每次都随机);
  • 输出:复杂分布的 y。

1. 为什么要用 GAN ?

image

1.1 需求

根据视频中怪物的前几帧的移动,预测其下一帧的运动。

1.2 结果

普通网络学习后,会出现不符合移动的分裂现象,即同一个怪物分成两个,向不同的方向移动,如下图。
image

1.3 原因

根据之前学习的帧,相同的位置,怪物存在两种不同方向的位移,所以模型为了“准确”表现学习的效果,在这个出现时,怪物的两种位移同时展现。
image

1.4 解决方法

GAN神经网络。
image

1.5 解决的本质问题

同样的输入可以有多种输出,如下图画画和问答。
image
image

2. 怎么训练 GAN ?

2.1 选择一个简单分布

如利用简单的正态分布作为随机Z输入到网络,训练得到复杂分布的动画人脸。
image

2.2 训练一个鉴别器(Discriminator)

  • 鉴别器:仍是一个神经网络,可以用CNN等。
  • 作用:输入图片,输出Scalar。
  • 量化指标:Scalar越大表示图片越接近真实标签,反之亦然。
    image

2.3 Generator 和 Discriminator

  • 本质:“物竞天择”。
  • Discriminator 会对 Generator 生成的结果进行打分筛选。
  • 每次筛选的打分影响 Generator 后续生成的结果。
  • 以此类推,进行“Adversarial 对抗”,直到训练完成。
    image
    image

2.4 训练步骤

  • 固定 Generator G,训练 Discriminator D。
    让 D 学习 G 生成的图片和真实图片的差别。
    image

  • 固定 Discriminator D,训练 Generator G。
    利用 D 给的分数,G 学习如何“骗”过 D。
    image

  • 总体
    image

3. GAN 背后的原理

image
image

  • 找到 可以使预测和真实值之间 Divergence(分歧)最小化的一组参数。

3.1 怎么找?

  • 获取分布
    • 从真实数据采样处得到真实分布
    • 从 G 对采样的简单分布生成的结果处得到预测分布

image

  • 训练 Discriminator
    • 最大化分辨器 Object Function 的结果
    • 因为是\(log\)和加法,即 \(V(G, D)\) 中真实数据的分数 \(D(y)\) 越高, 预测数据的分数 \(1 - D(y)\) 中的 \(D(y)\) 越低,越好。
    • \(V(G, D)\) 也可以理解为负交叉熵,即求交叉熵的最小值。

image
image

  • 通过观察 Object Function 获取 Divergence
    • Object Function的值和分歧值有相关性,具体如下图所示。
    • 也意味着求分歧可以通过求解最小化的max \(V(G, D)\),即 JS Divergence。

image
image
image
联系之前的具体步骤:
image

  • 那么现在每一种 Object Function 就可以形成一种 Divergence。

image

4. GAN的小技巧

  • GAN 不好训练
  • JS Divergence 不适合
    原因如下图
    image
    主要就是预测和真实的数据重叠部分很少。
    而 JS Divergence 需要数据重叠多,效果才好,具体如下图。
    image
    image
    image

4.1 Wasserstein Distance

抛弃 JS,选用新的 Divergence指标 ---> Wasserstein Distance。

  • 调整预测的分布使其接近真实数据的分布。

image
image

  • 选取最合适的调整方案进行调整,即调整所用的平均distance最小。

image

4.2 WGAN

posted @ 2022-10-26 09:53  bok_tech  阅读(46)  评论(0)    收藏  举报