条件GAN论文简单解读

    条件GAN(Conditional Generative Adversarial Nets),原文地址为CGAN

Abstract

    生成对抗网络(GAN)是最近提出的训练生成模型(generative model)的新方法。在本文中,我们介绍了条件GAN(下文统一简称为CGAN),简单来说我们把希望作为条件的data y同时送入generator和discriminator。我们在文中展示了在数字类别作为条件的情况下,CGAN可以生成指定的MNIST手写数字。我们同样展示了CGAN可以用来学习多形态模型(multi-modal model),我们提供了一个image-tagging的应用,其中我们展示了这种模型可以产生丰富的tags,这些tags并不是训练标签的一部分。

1. Introduction

    GAN最近被提出作为训练生成模型的替代框架,用来规避有些情况下近似复杂概率的计算的困难。     GAN的一个重要优势就是不需要计算马尔科夫链(Markov chains),只需要通过反向传播算法计算梯度,在学习过程中不需要进行推断(inference),一系列的factors和interactions可以被轻易地加入到model当中。     更进一步地,就像[8]中显示的那样,CGAN可以产生state-of-the-art的对数似然估计(log-likelihood)和十分逼真的样本。     在非条件的生成模型中,我们没法控制生成什么样模式的样本。然而,通过给model增加额外的信息,我们可以引导模型生成数据的方向。这样的条件可以建立在类别标签,或者[5]展示的图像修复的部分数据,甚至可以是不同模式的数据上。     本文展示了应该如何构建CGAN。我们展示了CGAN在两个数据集上的结果,一个是以类别标签作为条件的MNIST数据集;还有一个是建立在MIR Flickr 25,000 dataset上的多模态学习(multi-modal learning)。

2. Related Work

2.1 图像标签的多模态学习

    尽管最近监督神经网络(特别是卷积网络)取得了巨大的成功,但是将这种模型扩展到有非常大的预测输出类别的问题上仍然面临着巨大的挑战。第二个问题是,当今的大部分工作都主要集中在学习输入到输出的一对一的映射。然而很多有趣的问题可以考虑为概率上的一对多的映射。比如说在图片标注问题上,对于一个给定的图片可能对应了多个标签,不同的人类标注者可能会使用不同的(但通常是相似的或者是相关的)词汇来描述相同的一幅图片。     解决第一个问题的一种方式是从其他的模式中施加额外的信息,比如说通过语言模型来学习词汇的向量形式的表达,其中几何上的关系对应了语义上的相关。在这样的空间(映射之后的向量空间)做预测时,一个很好的性质时,即使我们的预测错误了,但是仍然和真实的答案很接近(比如说预测是"table"而不是"chair"),还有一个优势是,我们可以自然地对即使在训练时没有见过的词汇做generalizations prediction,因为相似的向量语义上也是相似的。[3]的工作显示即使是一个从图像特征空间到单词表达空间(word vector)的线性映射都可以提高分类的性能。     解决第二个问题的一种解决办法是使用条件概率生成模型,输入是作为条件变量,一对多的映射被实例化为一个条件预测分布。     [16]对第二个问题采用了和我们类似的办法,他们在MIR Flickr 25,000 dataset上训练了一个深度玻兹曼机。     除此之外,[12]的作者展示了如何训练一个有监督的多模态自然语言模型,这样可以为图片生成描述的句子。

3. Conditional Adersarial Nets(条件对抗网络)

3.1 Generative Adervasarial Nets

    GAN是最近提出的一种新颖的训练生成模型的方式。它包含了两个“对抗”模型:生成模型G捕获数据分布,判别模型D估计样本来自训练数据而不是G的概率。G和D都可以是非线性的映射函数,比如多层感知机模型。     为了学习生成器关于data x的分布$p_g$,生成器构建了一个从先验噪声分布$p_z(z)$到数据空间的映射$G(z;\theta_g)$。判别器$D(x;\theta_d)$输出了一个单一的标量,代表x来自训练样本而不是$p_g$的概率。     G和D是同时训练的:我们调整G的参数来最小化$log(1-D(g(Z)))$,然后调整D的参数来最小化$log(D(X))$,他们就像如下的两人的最小最大化博弈(two player min-max game),价值函数(value function)为$V(G,D)$:

3.2 Conditional Adersarial Nets

    如果生成器和判别器都基于一些额外的信息y的话,GAN可以扩展为一个条件模型。y可以是任何形式的辅助信息,比如说类别标签或者其他模式的数据。我们可以通过增加额外的输入层来将y同时输入生成器和判别器,来实施条件模型。     在生成器中,先验的噪声输入$p_z(z)$和y被结合成一个连接隐藏表达(joint hidden representation),对抗训练的框架为组成隐藏表达(compose of hidden representation)提供了相当大的灵活性。     在判别器中,x和y被作为输入送入判别函数(再一次地,比如可以是一个MLP,多层感知器)。     Two player minimax game的目标函数如公式(2):
    图1展示了一个简单的条件对抗网络的架构。
图1 条件对抗网络

4. Experimental Results

4.1 Unimodal(单一模式)

    我们以类别标签作为条件在MNIST数据集上训练了一个对抗网络,类别标签是作为one-hot vectors的形式。     在生成网络中,100维的噪声先验分布是从unit hypercube(单位超方体)的均匀分布采样得到的。z和y都是映射到带有relu激活函数的hidden layers,隐藏层节点数分别为200和1000,然后二者的输出相结合形成一个节点数为1200的带有relu激活函数的hidden layer,最后是一个sigmoid unit hidden layer作为输出,生成784维的MNIST samples。
    判别器将x映射到一个有240 units and 5 pieces的maxout layer[6],y映射到一个有50 units and 5 pieces 的maxout layer。这两个hidden layers在被送入sigmoid layer之前都被映射到一个有240 units and 4 pieces 的joint maxout layer。(判别器的准确的架构不是特别重要,只要有sufficient power即可;我们发现对于这个任务maxout units非常合适)。     模型的训练使用SGD,mini-batch size 为100,初始化的学习率为0.1,指数衰减因子为1.00004,最终的学习率为0.000001。momentum参数初始化为0.5,最终增加到0.7。generator和discriminator都需要使用dropout,dropout rate为0.5。在validation set 上的最佳对数似然估计被作为停止点(early stop)。     表1显示了对于MNIST的test data的Gaussian Parzen window对数似然估计。从10个类别的每一个类别采样共得到1000个samples,然后使用Gaussian Parzen window来拟合这些samples。然后我们使用Parzen window 分布来估计测试集的对数似然。([8]详细介绍了怎么做这种估计)。     条件对抗网络的结果显示了,我们的实验结果和基于其他网络得到的结果相近,但是比其中的几种方法更加优越——包括非条件对抗网络。我们展示这种优越性更多是基于概念上的,而不是具体的功效,我们相信,未来如果对超参数和模型架构进行更深入的探索,条件模型可以达到甚至超过非条件模型的结果。     图2显示了一些生成的样本,每一行是基于一个label生成的样本,而每一列则代表了生成的不同样本。
图2 生成的MNIST手写数字,每一行是以一个label作为条件

4.2 Multimodal(多模态)

    像Flickr这样的图像网站,是图像以及用户为图像生成的额外信息(user-generated metadata,UGM)的有标记数据的丰富来源——特别数用户提供的标签。     用户提供的标记信息与经典的图像标签不一样的地方在于,用户提供的标记信息内容更加丰富,语义上也更加接近人类用自然语言对于图像的描述,而不仅仅是识别出图像中有什么东西。UGM中同义词很普遍,不同的用户对相同的图像内容可能用不同的词汇去描述,因此,找到一种对这些标签进行标准化的有效方式是非常重要的。概念上的词向量是非常有用的,因为表达成为词向量之后,语义相近的词向量在距离上也是相近的。在本节当中,我们展示了图像的自动标记,可以带有多个预测标签,我们基于图像特征使用条件对抗网络生成(可能是多模态)标签向量的分布。     对于图像特征来说,我们采用和[13]类似的方法,在带有21000个标签的全部ImageNet数据集上预训练了一个卷积网络。我们使用了卷积网络最后一层带有4096个units的全连接层作为图像的特征表达。     对于单词表达来说,我们从[YFCC100M](http://webscope.sandbox.yahoo.com/catalog)数据集获取了用户标签,标题以及图像描述的语料库。在对文本进行预处理以及清洗之后,我们训练了一个skip-gram model,word vector的size是200。我们从词典当中丢弃了出现次数少于200次的词汇。最后词典的大小是247465。     我们在训练对抗网络过程中保持卷积网络和语言模型(language model)固定。未来我们将会探索,将反向传播同时应用于对抗网络,卷积网络和语言模型。     在实验的过程中,我们使用了MIR Flickr 25,000 数据集,并且使用了如上所述的卷积网络和语言模型提取了图像特征和标签(词向量)特征。没有任何标签的图像被我们舍弃了,注释被看做是额外的标签。前15万的样例被作为训练样本。有多个标签的图像,带有每一个标签的图像分别被看做一组数据。     评估过程,对于每个图像我们生成了100个samples,并且使用余弦距离找出了最相近的20个单词。然后我们选取了100个samples中最常出现的10个单词。表4.2展示了一些用户关联生成的标签和注释以及生成的标签。     表现最佳的条件对抗网络的生成器接收size为100的高斯噪声作为先验噪声,然后将它映射到500维的relu层,然后将4096层的图像特征映射到2000维的relu layer,这些层都被映射到一个200维的线性layer的然后连接表达,最后输出生成的词向量。     鉴别器由对于词向量500维的relu layer,对图像特征1200维的relu layer组成,然后是一个带有1000个units和3pieces的maxout layer,最后送入sigmoid单元得到输出。     模型的训练使用了随机梯度下降(SGD),batch size =100,初始学习率为0.1,指数衰减率为1.00004,最后学习率下降到0.000001。同时模型也使用了momentum(动量加速),初始值为0.5,最后上升到0.7。生成器和鉴别器都使用了dropout,dropout rate 为0.5。     超参数以及模型架构由交叉验证还有混合了手工以及的网格搜索的方法所得到。

5. Feature Work

    本文显示的结果非常初步,但是它展示了条件对抗网络的潜力,同时也为有趣且有用的应用提供了新的思路。在未来进一步的探索当中,我们希望展示更加丰富的模型以及对于模型表现、特性更加具体深入的分析。同时在当前的实验中,我们仅仅使用了每个单独的标签,我们希望可以通过一次使用多个标签取得更好的结果。     另外一个显然未来可以探索的方向是我们可以将对抗网络和语言模型结合到一起训练。[12]的工作显示了我们可以学习到针对特定任务的语言模型。

References

[1] Bengio, Y., Mesnil, G., Dauphin, Y., and Rifai, S. (2013). Better mixing via deep representations. In ICML’2013. [2] Bengio, Y., Thibodeau-Laufer, E., Alain, G., and Yosinski, J. (2014). Deep generative stochastic networks trainable by backprop. In Proceedings of the 30th International Conference on Machine Learning (ICML’14). [3] Frome, A., Corrado, G. S., Shlens, J., Bengio, S., Dean, J., Mikolov, T., et al. (2013). Devise: A deep visual-semantic embedding model. In Advances in Neural Information Processing Systems, pages 2121–2129. [4] Glorot, X., Bordes, A., and Bengio, Y. (2011). Deep sparse rectifier neural networks. In International Conference on Artificial Intelligence and Statistics, pages 315–323. [5] Goodfellow, I., Mirza, M., Courville, A., and Bengio, Y. (2013a). Multi-prediction deep boltzmann machines. In Advances in Neural Information Processing Systems, pages 548–556. [6] Goodfellow, I. J., Warde-Farley, D., Mirza, M., Courville, A., and Bengio, Y. (2013b). Maxout networks. In ICML’2013. [7] Goodfellow, I. J., Warde-Farley, D., Lamblin, P., Dumoulin, V., Mirza, M., Pascanu, R., Bergstra, J., Bastien, F., and Bengio, Y. (2013c). Pylearn2: a machine learning research library. arXiv preprint arXiv:1308.4214. [8] Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative adversarial nets. In NIPS’2014. [9] Hinton, G. E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. (2012). Improving neural networks by preventing co-adaptation of feature detectors. Technical report, arXiv:1207.0580. [10] Huiskes, M. J. and Lew, M. S. (2008). The mir flickr retrieval evaluation. In MIR ’08: Proceedings of the 2008 ACM International Conference on Multimedia Information Retrieval, New York, NY, USA. ACM. [11] Jarrett, K., Kavukcuoglu, K., Ranzato, M., and LeCun, Y. (2009). What is the best multi-stage architecture for object recognition? In ICCV’09. [12] Kiros, R., Zemel, R., and Salakhutdinov, R. (2013). Multimodal neural language models. In Proc. NIPS Deep Learning Workshop. [13] Krizhevsky, A., Sutskever, I., and Hinton, G. (2012). ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25 (NIPS’2012). [14] Mikolov, T., Chen, K., Corrado, G., and Dean, J. (2013). Efficient estimation of word representations in vector space. In International Conference on Learning Representations: Workshops Track. [15] Russakovsky, O. and Fei-Fei, L. (2010). Attribute learning in large-scale datasets. In European Conference of Computer Vision (ECCV), International Workshop on Parts and Attributes, Crete, Greece. [16] Srivastava, N. and Salakhutdinov, R. (2012). Multimodal learning with deep boltzmann machines. In NIPS’2012. [17] Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., and Rabinovich, A. (2014). Going deeper with convolutions. arXiv preprint arXiv:1409.4842.
posted @ 2018-06-05 15:49  lyrichu  阅读(8252)  评论(0编辑  收藏  举报