[学习笔记] Gibbs Sampling

Gibbs Sampling

Intro

Gibbs Sampling 方法是我最近在看概率图模型相关的论文的时候遇见的,采样方法大致为:迭代抽样,最开始从随机样本中抽样,然后将此样本作为条件项,按条件概率抽样,每次只从一个维度考虑,当所有维度均采样完,开始下一轮迭代。

Random Sampling

原理

这是基于反函数的采样方法。假设我们已知均匀分布如何采样,即能够生成一个0-1内的随机数,我们可以将均匀分布的采样结果映射到其他分布上去,使得我们经过映射得到目标分布的采样值。

这个映射是将均匀分布的值作为目标分布的概率分布函数,即

\[Y = F(X) = \int_x f(x)dx \]

\(Y\)是均匀分布的随机变量,其取值是0到1,很显然我们可以通过概率分布函数的逆函数求的目标分布的随机变量值:

\[x = F^{-1}(y) \]

例子

已知\(Y \sim U(0,1)\)\(Y=F(X)=X^2\),求\(X\)的采样。

反函数:

\[X = \sqrt{Y} \]

import torch
y = torch.rand()
x = torch.sqrt(y)

有一个问题就是,当 $ p(x) $ 的累积分布函数的反函数无法计算的时候,如何采样呢?

这时候就要用到一些采样的策略,比如拒绝采样、重要性采样、Gibbs采样等等。

Rejection Sampling

原理

用已知分布\(q\)(提议分布),去覆盖目标分布\(p\),在提议分布上采样,可以得到两个分布在该点的概率密度值,这个时候再在\([0,kq]\)上均匀采样一个随机变量\(z\),如果随机变量落在\([0,p]\)上则接受。最终采样的区域在二维上表现为下图红色区域以下x轴以上的部分,是将概率密度图像当成一个二维形式的变量来看的,最终舍弃纵轴的随机变量只要x轴的。其等价于其他资料中的在\([0,1]\)上采样随机变量\(z\),然后\(z\leq\frac{p}{kq}\)则接受。

然而拒绝采样要求提议分布和原始分布比较接近,这样采样率才会比较高,否则这个采样方法就是低效的。
例子

\(X \sim U(0,1)\)\(Y=X^2\),求\(Y\)的采样。

提议分布可以用\(X\),因为其平方肯定小于\(X\)

可以得到

\[f(x) = 1 \]

\[f(y)=\frac{1}{2\sqrt{y}} \]

import torch
num_samples = 1000
samples = []
for i in range(num_samples):
	x = torch.rand() # 采样提议分布
	fx = 1                     # x的概率密度
	fy = 1 / 2 / torch.sqrt(x) # y的概率密度
    z = torch.rand(0, fx) # 随机变量z
    if z < fy: # 条件
        samples.append(x) 

Importance Sampling

原理

重要性采样的目的不是采样,而是以更小的方差估计目标分布的期望,将目标分布的期望转化为了在提议分布下新的被积函数的期望,转化的好处是在提议分布下的采样值防差更小。

对于期望用蒙特卡洛方法求解:

\[\int f(z) p(z) d z \approx \frac{1}{L} \sum_{l=1}^L f\left(z^{(l)}\right) \]

此时如果积分是一个反常积分,比如\(f(z)\)是一个无界函数,期望的估计会非常不准确,比如\(f(z) = z^{-a}e^z\)\(z=0\)附近采样将会导致蒙特卡洛采样的方差非常大。通过引入提议分布,可以将原被积函数转化为一个有界函数,从而降低蒙特卡洛采样的方差。

\[\begin{aligned}\mathrm{E}[f] & =\int f(z) p(z) d z \\& =\int f(z) \frac{p(z)}{q(z)} q(z) d z \\& \approx \frac{1}{L} \sum_{l=1}^L \frac{p\left(z^{(l)}\right)}{q\left(z^{(l)}\right)} f\left(z^{(l)}\right) \\\omega\left(z^{(l)}\right) & =p\left(z^{(l)}\right) / q\left(z^{(l)}\right)\end{aligned} \]

例子

求积分\(\int_0^1x^{-a}e^xdx\).

令提议分布为\(q(x) = \frac{x^{-a}}{1-a}\)

则此时积分变为了\(\int_0^1(1-a)e^x q(x)dx = \mathbb{E}_q[(1-a)e^x]\),被积函数有界。

将其转化为蒙特卡洛采样形式:

\[\frac{1}{L}\sum_{i=1}^L(1-a)e^{x_i} \]

其中\(x_i \sim q(x)\),对\(q(x)\)的采样可以利用上面的反函数采样法。

\(y\sim U[0,1]\),\(Y=\int_0^Xq(X)dX=\frac{X^{-a+1}}{(1-a)^2}\)

得反函数:

\[X = ((1-a)^2Y)^{\frac{1}{1-a}} \]

import torch
num_samples = 1000
expectation = 0.
for i in range(num_samples):
	y = torch.rand()
	x = ((1-a)**2 * y)**(1/(1-a))
	expectation += (1-a) * torch.exp(x) / num_samples

Gibbs Sampling

原理
假设有一随机向量\(x = (x_1,x_2,...,x_d)\),其中d表示他有d维,每一维是一随机变量,且并不是我们常见的相互独立前提。那么,如果我们已知这个随机向量的概率分布,我们如何从这个分布中进行采样呢?

显然想要从多元分布的联合概率分布中直接抽样是相当困难的,而Gibbs Sampling就是一种简单而且有效的采样方法。吉布斯采样的大致步骤如下:

从一个随机的初始化状态\(x^{(0)}=[x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}]\)开始,对每个维度单独进行采样,其采样顺序大致如下:

\[x_1^{(1)} \thicksim p(x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\x_2^{(1)} \thicksim p(x_2|x_1^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\\vdots \\x_d^{(1)} \thicksim p(x_d|x_1^{(0)},x_2^{(0)},\cdots,x_{d-1}^{(0)}) \\\vdots \\x_1^{(t)} \thicksim p(x_1|x_2^{(t-1)},x_3^{(t-1)},\cdots,x_d^{(t-1)}) \\\vdots\\x_{d}^{(t)} \thicksim p(x_d|x_1^{(t-1)},x_2^{(t-1)},\cdots,x_{d-1}^{(t-1)}) \\ \]

遵从上面的采样步骤,我们最终能够采样得到所需要的高维分布的样本。需要注意的是,迭代的最开始采样得到的样本并不是完全满足所需要的分布的样本,因为采样之初采样的分布是提议分布,一般是均匀分布,而Gibbs Sampling的过程更像是一个单步迭代的过程,这使我想起了EM算法,都是一样的,一步一步去迭代达到最终结果。

我在网上找到了一个能够描述这个过程的图片:

如上图所示,右图是我们需要的分布,左边是迭代的过程,最开始抽样的点0和1都是均匀分布抽样得到的,而越到后面,抽样的点都越满足我们右边的分布,所以这个过程可以说明Gibbs Sampling抽样的过程是可行的。

还有下面这张图,也差不多:

例子

Gibbs Sampling我是从一篇图像合成的论文中看到并有所了解的,文章基于MRF,使用神经网络去拟合条件分布\(p(x_i|x_{-i})\),其中\(x_{-i}\)表示除了第i个属性的其他属性。

具体到图像中来,\(x_i\)就是第i个位置的像素点的像素值,而\(x_{-i}\)描述的就是除了这个点以外的其他所有点,因此上式的概率分布就是一个条件分布。

使用神经网络可以拟合出这个分布来,那么如何去生成图片又是一个问题。

文章给出的解决方案就是Gibbs Sampling,先从随机噪声开始,逐像素进行生成,第一次迭代完成将生成一张图片,那么第二次第三次依次可以使用上一次迭代完前生成的图片进行迭代生成下一次,当迭代次数足够多的时候,即我们认为达到了平稳分布,这个时候生成的图片就是服从该分布的图片了。

原文参见:

原文链接

具体的,我给出下面的代码:

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from PIL import Image
import glob
import random
import cv2 as cv
class MConv(nn.Conv2d):
    '''
    mask_type A or B
    A : the center is zero
    B : the center is not zero
    '''
    def __init__(self,mask_type,*args,**kwargs):
        super(MConv,self).__init__(*args,**kwargs)
        assert mask_type in ["A","B"]
        self.mask_type = mask_type
        self.register_buffer('mask', self.weight.data.clone())
        _,_,h,w = self.weight.size()
        self.mask.fill_(1)
        self.mask[:,:,h//2,w//2 + (mask_type == 'B'):] = 0
        self.mask[:,:,h//2+1:,:] = 0
        
    def forward(self,x):
        self.weight.data *= self.mask
        return super(MaskedConv2d,self).forward(x)
    
    
class DoublePixelCNN(nn.Module):
    def __init__(self,fm,kernel_size = 7,padding = 3):
        super(DoublePixelCNN, self).__init__()
        self.net1 = nn.Sequential(
                MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                #nn.Conv2d(fm, 256, 1)
        ) 
        self.net2 = nn.Sequential(
                MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                #nn.Conv2d(fm, 256, 1)
        ) 
        
        self.conv1x1 = nn.Conv2d(fm*2, 256, 1)
    def forward(self,x):
        x1 = self.net1(x)
        x2 = self.net2(x.flip(dims = [-1,-2]))
        x = torch.cat([x1,x2.flip(dims = [-1,-2])],dim = 1)
        x = self.conv1x1(x)
        return x

if __name__ == "__main__":
	tr =       data.DataLoader(datasets.MNIST(root="/media/xueaoru/Ubuntu/dataset/data",transform=transforms.ToTensor(),),
                     batch_size=64, shuffle=True, num_workers=12, pin_memory=True)
    net = DoublePixelCNN(128)
    net.cuda()
    sample = torch.rand(64,1,k,k).cuda()
    optimizer = optim.Adam(net.parameters(),lr = 0.0001)
    for epoch in range(1000):
        net.train()
        running_loss = 0.
        for input,_ in tqdm(tr):
            #print(input.size())
            input = input.cuda()
            #target = target.cuda()
            target = (input.data[:,:] * 255).long() # (b,3,h,w)
            # net(input) (b,256,3,h,w)
            loss = F.cross_entropy(net(input), target) # 计算的是每个像素的二分类交叉熵
            running_loss += loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("training loss: {:.8f}".format(running_loss / len(tr)))
        if epoch % 5 == 0:
            torch.save(net.state_dict(),open("./{}.pth".format(epoch),"wb"))
            #sample.fill_(0)
            net.eval()
            with torch.no_grad():
                for t in tqdm(range(300)):
                    for i in range(k):
                        for j in range(k):
                            out = net(sample) # (b,256)
                            probs = F.softmax(out[:, :, i ,j],dim = 1).data # (b,c) = (16,256)
                            sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
                
                utils.save_image(sample, 'sample_{:02d}.png'.format(epoch), nrow=12, padding=0)
    			sample = torch.rand(64,1,k,k).cuda()

由于这个方法采样时间极其缓慢,所以我生成的图片尺度比较小,训练周期也比较短,只是做个demo使用。

posted @ 2019-12-24 16:59  aoru45  阅读(6623)  评论(0编辑  收藏  举报