GMM算法和Python简单实现

GMM算法和Python简单实现

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u010866505/article/details/77897632
GMM算法

第一章引子

假设放在你面前有5篮子鸡蛋,每个篮子有且仅有一种蛋,这些蛋表面上一模一样,就是每一种蛋涵盖有且只有一种维生素,分别是A、B、C、D、E。这个时候,你需要估计这五个篮子的鸡蛋的平均重量μ。 首先有个总的假设: 假设每一种维生素的鸡蛋的重量都服从高斯分布。 这个时候,因为每个篮子的鸡蛋包含有且只有一种,并且彼此之间相同的维生素,即每个篮子的鸡蛋都服从相同的分布,这个时候可以用极大似然估计去估计每一种维生素鸡蛋的平均重量。

现在问题来了: 我把那5种鸡蛋混在一起,这个时候要你去估计这5中鸡蛋的平均重量和方差? 仍旧假设:每一种维生素的鸡蛋的重量都服从高斯分布 这个时候有两个参数需要估计均值μ和方差σ。 你从5中鸡蛋里面去拿一种鸡蛋,每种鸡蛋被拿到的概率是呈一定分布的(不一定就是均匀分布,然后概率是1/5)。假设第j中鸡蛋被拿到的概率是φj

因此你有三个未知量,即隐含量: φj、均值μ和方差σ。

现在是如果你知道类别z(j),你就能根据极大似然估计计算其他两个隐含量。令每个鸡蛋的重量用x(i)表示,x(i)服从高斯分布。 。(x(i)|z(i)=j)~N(μ,σ2)

高斯混合模型(gaussian mixture model-GMM)就是上述问题的抽象模型,即对于每一个样本x(i),先从k个类别中按某种分布抽取一个z(i),然后从这个类别下的高斯分布中生成一个样本x(i)

第二章 GMM数学推导

假设独立样本集是{x(1)......x(m)},这些样本是从k个高斯分布的数据里面抽取出来的。从k不同分布中抽取某一类别是呈某种分布,假设抽取到类别z(i)的概率是p(z(i)=j)= φj。因此这里面会有三个隐含变量: φj、均值μ和方差σ,令这三个变量构成一个集合θ.

则利用极大似然估计θ就是我们非常熟悉的极大似然估计函数:

 

 

 

 

 

即:

 

 

 

就连乘变成连加,取对数有:

 

 

接下来,我们利用EM算法来估计这些参数。EM算法请参看(http://blog.csdn.net/u010866505/article/details/77877345).

1.       先假设我们知道样本的类别z(i),然后计算期望E,即后验概率:

 

 

2.       然后就是M-step:

 

 

求偏导可以计算均值 μj:

 

 
 

在公式(1)中,如果均值和方差固定的话,那么(1)式可以简化成:

 

 

为了计算优化(3)式,而 又满足一定的条件,即 ,如果你知道大名鼎鼎的支持向量机(SVM)的优化目标函数是如何得到的话?这里也就明白了。拉格朗日乘子。构造拉格朗日函数如下:

 
 
求导:
 
 

令(5)式=0,计算有:

 

上面(6)式两边做如下处理:

 



                                                                                

因此(6)式得到如下变换:

 

 

接下来就是计算方差:对(1)公式计算方差 的偏导:

 

 

接下来要计算方差的偏导,因此(9)式中括号内的第一项和最后一项可以不用考虑。于是有:
 
 
 

令(10)式等于0,可以计算得到方差的公式如下:

 

 

(2)+(8)+(11)就是我们估计参数的最后的解。

 第三章 GMM代码-python实现

如下是GMM算法的简单实现,
  1.  
    #! /usr/bin/env python
  2.  
    #! -*- coding=utf-8 -*-
  3.  
     
  4.  
    #模拟两个正态分布的参数
  5.  
     
  6.  
    from numpy import *
  7.  
    import numpy as np
  8.  
    import random
  9.  
    import copy
  10.  
     
  11.  
    SIGMA = 6
  12.  
    EPS = 0.0001
  13.  
    #均值不同的样本
  14.  
    def generate_data():
  15.  
    Miu1 = 2
  16.  
    Miu2 = 4
  17.  
    sigma1 = 1
  18.  
    sigma2 = 2
  19.  
    alpha1 = 0.4
  20.  
    alpha2 = 0.6
  21.  
    N = 5000
  22.  
    N1 = int(alpha1 * N)
  23.  
    X = mat(zeros((N,1)))
  24.  
    for i in range(N1):
  25.  
    temp = random.uniform(0,0.5)
  26.  
    X[i] = temp * sigma1 + Miu1
  27.  
    for i in range(N-N1):
  28.  
    temp = random.uniform(0,0.5)
  29.  
    X[i+N1] = temp * sigma2 + Miu2
  30.  
    return X
  31.  
     
  32.  
    #EM算法
  33.  
    def my_GMM(X):
  34.  
    k = 2
  35.  
    N = len(X)
  36.  
    Miu = np.random.rand(k,1)
  37.  
    Posterior = mat(zeros((N,k)))
  38.  
    sigma = np.random.rand(k,1)
  39.  
    sigma[0]=1
  40.  
    #sigma[1]=2
  41.  
    alpha = np.random.rand(k,1)
  42.  
    alpha[0] = 0.1
  43.  
    alpha[1] = 0.9
  44.  
    dominator = 0
  45.  
    numerator = 0
  46.  
    #先求后验概率
  47.  
    print sigma
  48.  
    for it in range(1000):
  49.  
    for i in range(N):
  50.  
    dominator = 0
  51.  
    for j in range(k):
  52.  
    dominator = dominator + np.exp(-1.0/(2.0*sigma[j]) * (X[i] - Miu[j])**2)
  53.  
    #print -1.0/(2.0*sigma[j]),(X[i] - Miu[j])**2,-1.0/(2.0*sigma[j]) * (X[i] - Miu[j])**2,np.exp(-1.0/(2.0*sigma[j]) * (X[i] - Miu[j])**2)
  54.  
    #return
  55.  
    for j in range(k):
  56.  
    numerator = np.exp(-1.0/(2.0*sigma[j]) * (X[i] - Miu[j])**2)
  57.  
    Posterior[i,j] = numerator/dominator
  58.  
    oldMiu = copy.deepcopy(Miu)
  59.  
    oldalpha = copy.deepcopy(alpha)
  60.  
    oldsigma = copy.deepcopy(sigma)
  61.  
    #最大化
  62.  
    for j in range(k):
  63.  
    numerator = 0
  64.  
    dominator = 0
  65.  
    for i in range(N):
  66.  
    numerator = numerator + Posterior[i,j] * X[i]
  67.  
    dominator = dominator + Posterior[i,j]
  68.  
    Miu[j] = numerator/dominator
  69.  
    alpha[j] = dominator/N
  70.  
    tmp = 0
  71.  
    for i in range(N):
  72.  
    tmp = tmp + Posterior[i,j] * (X[i] - Miu[j])**2
  73.  
    #print tmp,Posterior[i,j],(X[i] - Miu[j])**2
  74.  
    sigma[j] = tmp/dominator
  75.  
    print tmp, dominator, sigma[j]
  76.  
    if ((abs(Miu - oldMiu)).sum() < EPS) and \
  77.  
    ((abs(alpha - oldalpha)).sum() < EPS) and \
  78.  
    ((abs(sigma - oldsigma)).sum() < EPS):
  79.  
    print Miu,sigma,alpha,it
  80.  
    break
  81.  
     
  82.  
    if __name__ == '__main__':
  83.  
    X = generate_data()
  84.  
    my_GMM(X)


参考资料
posted on 2019-12-03 16:58  曹明  阅读(1317)  评论(0)    收藏  举报