隐马尔可夫模型及的评估和解码问题

HMM介绍

Hidden Markov Models是一种统计信号处理方法,模型中包含2个序列和3个矩阵:状态序列S、观察序列O、初始状态矩阵P、状态转移矩阵A、混淆矩阵B。举个例子来说明。

你一个异地的朋友只做三种活动:散步、看书、做清洁。每天只做一种活动。假设天气只有两种状态:晴和兩。每天只有一种天气。你的朋友每天告诉你他做了什么,但是不告诉你他那里的天气。

某一周从周一到周五每天的活动分别是{读书,做清洁,散步,做清洁,散步}----这就是观察序列O,因为你可以观察得到。

从周一到周五的天气依次是{晴,兩,晴,晴,晴}----这就是状态序列S,状态序列是隐藏的,你不知道。

根据长期统计,某天晴的概率是0.6,兩的概率是0.4。则$\pi=[0.6,0.4]$。

从晴转晴的概率是0.7,从晴转兩的概率是0.3。从兩转晴的概率是0.4,从兩转兩的概率是0.6。则$A=\left[\begin{array}{cc}0.7&0.3\\0.4&0.6\end{array}\right]$

天气晴时,散步的概率是0.4,看书的概率是0.3,做清洁的概率是0.3。天气兩时,散步的概率是0.1,看书的概率是0.4,做清洁的概率是0.5。则$B=\left[\begin{array}{cc}0.4&0.3&0.3\\0.4&0.1&0.5\end{array}\right]$

该模型和实际情况有明显不符的地方:用一简单的状态转移矩阵A来表示状态的转移概率的前提是t时刻的状态只跟t-1时刻的状态有关,而实际上今天的天气跟过去几天的天气都有关系,而且跟过去几天的晴朗程度、兩量大小都有关系;混淆矩阵B认为今天的活动只跟今天的天气有关系,实际上今天的活动跟过去几天的活动也有关系,比如过去一周都没有打扫房间,那今天做清洁的概率就大大增加。

模型介绍完了。

评估问题

隐马尔可夫模型中包含一个评估问题:已知模型参数,计算某一特定输出序列的概率。通常使用forward算法解决。 

比如计算活动序列{读书,做清洁,散步,做清洁,散步}出现的概率,就属于评估问题。

如果穷举的话,观察序列会有2^5种,需要分别计算它们出现的概率,然后找出概率最大的。

穷举法中有很多重复计算,向前算法就是利用已有的结果,减少重复计算。

算法借助于一个矩阵Q[LEN][M],其中M是所有状态的种数,Q[i][j]表示从第0天到第i天,满足观察序列,且第i天隐藏状态为Sj的所有可能的隐藏序列的概率之和。最终所求结果为Q[LEN-1][0]+...+Q[LEN-1][M-1],即最后一天,所有隐藏状态下,分别满足观察序列的概率值之和。

比如Q[0][0]=p(第一天做卫生 且 第一天晴)=p(天晴)*p(做卫生|天晴)=P[0]*B[0][2]=0.6*0.3=0.18

Q[0][1]=p(第一天做卫生 且 第一天下雨)=p(下雨)*p(做卫生|下雨)=P[1]*B[1][2]=0.4*0.5=0.2

Q[1][0]=p(第一天做卫生 且 第二天晴 且 第二天做卫生)

=p(第一天做卫生 且 第二天晴)*p(天晴的情况下做卫生)

=p{ p(第一天做卫生 且 第一天晴)*p(从天晴转天晴)+p(第一天做卫生 且 第一天下雨)*p(从下雨转天晴) }*p(天晴的情况下做卫生)

={ Q[0][0]*A[0][0] + Q[0][1]*A[1][0] } * B[0][2]

Q[1][1]= ……

…… ……

可以看到计算Q矩阵的每i行时都用到了第i-1行的结果。

解码问题

解码问题是:已知模型参数,寻找最可能的能产生某一特定输出序列O(LEN)的隐含状态的序列。通常使用Viterbi算法解决。 

观察序列长度为LEN,则隐藏状态序列长度也是LEN,如果采用穷举法,就有M^LEN种可能的隐藏状态序列,我们要计算每一种隐藏状态到指定观察序列的概率,最终选择概率最大的。

穷举法中有很多重复计算,Viterbi算法就是利用已有的结果,减少重复计算。

跟评估问题非常相似,不同点在于评估算的是和,解码算的是最大值。

Viterbi算法主要就是在计算一个矩阵Q[LEN][M],其中Q[i][j]表示从第0天到第i天,满足观察序列,且第i天隐藏状态为Sj的所有可能的隐藏序列的概率的最大值。另外还要建立一个矩阵Path[LEN][M],用来记录状态序列中某一状态之前最可能的状态。

举个例子,假如指定观察序列是{读书,做卫生,散步,做卫生,散步},求出现此观察序列最可能的状态序列是什么。

Q[0][0]=p(第一天读书 且 第一天晴)=p(天晴)*p(读书|天晴)

Path[0][0]=-1;

Q[0][1]=p(第一天读书 且 第一天下雨)=p(下雨)*p(读书|下雨)

Path[0][1]=-1;

关键是从第二天开始,Q[1][0]表示:满足“第一天读书 且 第二做卫生 且 第二天晴”的所有可能的隐藏序列的概率的最大值。那么满足“第一天读书 且 第二做清洁 且 第二天晴”的所有可能的隐藏序列有哪些呢?第二天是必须满足晴天的,第二天之前的状态可以任意变。则所有可能的隐藏序列就是“晴  晴”和“雨  晴”。实际上考虑第三天(及第三天以后)时,并不需要考虑“所有”可能的隐藏序列,而只需要考虑第二天的不同状态取值,这是因为马氏过程有无后效性--tm时刻所处状态的概率只和tm-1时刻的状态有关,而与tm-1时刻之前的状态无关。

Q[1][0]=

max{ p(第一天晴 且 第一天读书 且 第二天晴 且第二天做卫生) ,p(第一天下雨 且 第一天读书 且 第二天晴 且 第二天做卫生) }

=max{ p(第一天读书 且 第一天晴)*p(天晴转天晴),p(第一天读书 且 第一天下雨)*p(下雨转天晴) } * p(做卫生|天晴)

=max{ Q[0][0]*A[0][0],Q[0][1]*A[1][0] } * B[0][2]

假如Q[0][0]*A[0][0] < Q[0][1]*A[1][0],则Path[1][0]=1;假如Q[0][0]*A[0][0] > Q[0][1]*A[1][0],则Path[1][0]=0。

Q[1][1]= ……

…… ……

可以看到计算Q矩阵的每i行时都用到了第i-1行的结果。

下面给出两种算法的Java代码:

  1 import java.io.BufferedReader;
  2 import java.io.File;
  3 import java.io.FileNotFoundException;
  4 import java.io.FileReader;
  5 import java.io.IOException;
  6 import java.util.Collections;
  7 import java.util.HashMap;
  8 import java.util.LinkedList;
  9 import java.util.List;
 10 import java.util.Map;
 11 import java.util.Map.Entry;
 12 
 13 import xxzl.dm.utility.FileUtil;
 14 import xxzl.dm.utility.Pair;
 15 import xxzl.dm.utility.math.Smooth;
 16 
 17 /**
 18  * 隐马尔可夫推断,包括评估问题和解码问题
 19  * 
 20  * @Author:zhangchaoyang
 21  * @Since:2015年3月29日
 22  * @Version:1.0
 23  */
 24 public class HmmInference {
 25 
 26     /**
 27      * HMM的模型参数
 28      */
 29     private List<String> stateSet = new LinkedList<String>();// 状态值集合
 30     private List<String> observeSet = new LinkedList<String>();// 观察值集合
 31     private double[] stateProb;// 初始状态概率矩阵
 32     private double[][] stateTrans;// 状态转移矩阵
 33     private double[][] emission;// 发射矩阵
 34     private double[] minEmission;// 发射矩阵每一行的极小值(用于零值平滑)
 35 
 36     /**
 37      * 使用已标记好的训练样本,经过简单统计使用极大似然确定HMM的参数
 38      * 
 39      * @param tagFile
 40      *            文件格式:每行2列,第1列是观察值,第2列是状态值,不用序列之间用空行隔开。<br>
 41      *            /test/resources/corpus/wordcut.train是一个示例文件。
 42      */
 43     public void initParam(String tagFile) {
 44         Map<String, Integer> stateIndexMap = new HashMap<String, Integer>();// 状态值及其编号
 45         Map<String, Integer> observeIndexMap = new HashMap<String, Integer>();// 观察值及其编号
 46         int[] stateCount;// 状态值及其计数
 47         int[][] stateTransCount;// 状态转移计数矩阵
 48         int[][] confusionCount;// 混淆计数矩阵
 49 
 50         try {
 51             BufferedReader br = new BufferedReader(new FileReader(new File(
 52                     tagFile)));
 53             if (br.markSupported()) {
 54                 br.mark(1024 * 1024 * 100);
 55             }
 56             String line = null;
 57             int stateTotal = 0;
 58             int observeTotal = 0;
 59             // 第一趟扫描文件,给stateIndexMap、observeIndexMap赋值
 60             while ((line = br.readLine()) != null) {
 61                 String[] arr = line.split("\\s+");// 每行存储一个观察值、一个状态值,用空白符隔开。用空行隔开不同的序列。
 62                 if (arr.length >= 2) {
 63                     String observe = arr[0];
 64                     String state = arr[1];
 65                     if (!observeIndexMap.containsKey(observe)) {
 66                         observeIndexMap.put(observe, observeTotal++);
 67                     }
 68                     if (!stateIndexMap.containsKey(state)) {
 69                         stateIndexMap.put(state, stateTotal++);
 70                     }
 71                 }
 72             }
 73             if (br.markSupported()) {
 74                 br.reset();
 75             } else {
 76                 br.close();
 77                 br = new BufferedReader(new FileReader(new File(tagFile)));
 78             }
 79             // System.out.println("state set:");
 80             // for (Entry<String, Integer> entry : stateIndexMap.entrySet()) {
 81             // System.out.println(entry.getValue() + ":" + entry.getKey());
 82             // }
 83             // 第二趟扫描文件,给stateTransCount、confusionCount、stateCount赋值
 84             for (int i = 0; i < stateTotal; i++) {
 85                 stateSet.add("");
 86             }
 87             for (int i = 0; i < observeTotal; i++) {
 88                 observeSet.add("");
 89             }
 90             stateTransCount = new int[stateTotal][];
 91             for (int i = 0; i < stateIndexMap.size(); i++) {
 92                 stateTransCount[i] = new int[stateTotal];
 93             }
 94             confusionCount = new int[stateTotal][];
 95             for (int i = 0; i < stateTotal; i++) {
 96                 confusionCount[i] = new int[observeTotal];
 97             }
 98             stateCount = new int[stateTotal];
 99             String preState = null;
100             while ((line = br.readLine()) != null) {
101                 String[] arr = line.split("\\s+");
102                 if (arr.length >= 2) {
103                     String observe = arr[0];
104                     String state = arr[1];
105                     int row = stateIndexMap.get(state);
106                     int col = 0;
107                     int oldCount = 0;
108                     // if (observeIndexMap.containsKey(observe)) {
109                     col = observeIndexMap.get(observe);
110                     oldCount = confusionCount[row][col];
111                     confusionCount[row][col] = oldCount + 1;
112                     // }
113                     stateCount[row] = stateCount[row] + 1;
114                     if (preState == null) {
115                         preState = state;
116                     } else {
117                         row = stateIndexMap.get(preState);
118                         col = stateIndexMap.get(state);
119                         oldCount = stateTransCount[row][col];
120                         stateTransCount[row][col] = oldCount + 1;
121                         preState = state;
122                     }
123                 } else {
124                     preState = null;
125                 }
126             }
127             br.close();
128             // 给HMM基本参数赋值
129             for (Entry<String, Integer> entry : stateIndexMap.entrySet()) {
130                 String state = entry.getKey();
131                 int index = entry.getValue();
132                 stateSet.set(index, state);
133             }
134             for (Entry<String, Integer> entry : observeIndexMap.entrySet()) {
135                 String observe = entry.getKey();
136                 int index = entry.getValue();
137                 observeSet.set(index, observe);
138             }
139             stateProb = calProbByCount(Smooth.GoodTuring(stateCount));
140             // System.out.println("state initial prob:");
141             // System.out.println(Arrays.toString(stateProb));
142             stateTrans = new double[stateTransCount.length][];
143             for (int i = 0; i < stateTransCount.length; i++) {
144                 stateTrans[i] = calProbByCount(stateTransCount[i]);// 计算状态转移概率时不作平滑,因为有些状态之间转移的概率就应该是0,如果平滑就会变成非0
145             }
146             // System.out.println("state transiation prob:");
147             // for (int i = 0; i < stateTransProb.length; i++) {
148             // System.out.println(Arrays.toString(stateTransProb[i]));
149             // }
150             emission = new double[confusionCount.length][];
151             minEmission = new double[confusionCount.length];
152             for (int i = 0; i < confusionCount.length; i++) {
153                 emission[i] = calProbByCount(Smooth
154                         .GoodTuring(confusionCount[i]));
155                 double min = Double.MAX_VALUE;
156                 for (double ele : emission[i]) {
157                     if (ele < min) {
158                         min = ele;
159                     }
160                 }
161                 minEmission[i] = min;
162             }
163         } catch (FileNotFoundException e) {
164             e.printStackTrace();
165         } catch (IOException e) {
166             e.printStackTrace();
167         }
168 
169     }
170 
171     /**
172      * 采用向前算法解决评估问题:给定HMM的所有参数,评估一个观察序列出现的概率。
173      * 
174      * @param obs_seq
175      * @return
176      */
177     public double estimate(List<String> obs_seq) {
178         double rect = 0.0;
179         int LEN = obs_seq.size();
180         double[][] Q = new double[LEN][];
181         // 状态的初始概率,乘上隐藏状态到观察状态的条件概率。
182         Q[0] = new double[stateSet.size()];
183         for (int j = 0; j < stateSet.size(); j++) {
184             if (observeSet.contains(obs_seq.get(0))) {
185                 Q[0][j] = stateProb[j]
186                         * emission[j][observeSet.indexOf(obs_seq.get(0))];
187             } else {
188                 Q[0][j] = stateProb[j] * minEmission[j];
189                 System.err.println("观察值'" + obs_seq.get(0) + "'在已标记样本中未出现过");
190             }
191         }
192         // 首先从前一时刻的每个状态,转移到当前状态的概率求和,然后乘上隐藏状态到观察状态的条件概率。
193         for (int i = 1; i < LEN; i++) {
194             Q[i] = new double[stateSet.size()];
195             for (int j = 0; j < stateSet.size(); j++) {
196                 double sum = 0.0;
197                 for (int k = 0; k < stateSet.size(); k++) {
198                     sum += Q[i - 1][k] * stateTrans[k][j];
199                 }
200                 if (observeSet.contains(obs_seq.get(i))) {
201                     Q[i][j] = sum
202                             * emission[j][observeSet.indexOf(obs_seq.get(i))];
203                 } else {
204                     Q[i][j] = sum * minEmission[j];
205                     System.err
206                             .println("观察值'" + obs_seq.get(0) + "'在已标记样本中未出现过");
207                 }
208             }
209         }
210         for (int i = 0; i < stateSet.size(); i++)
211             rect += Q[LEN - 1][i];
212         return rect;
213     }
214 
215     /**
216      * 采用viterbi进行解码:给定HMM的所有参数,给一个观察序列,评估最可能的状态序列是什么。
217      * 
218      * @param observe
219      * @return
220      */
221     public Pair<Double, LinkedList<String>> viterbi(List<String> observe) {
222         LinkedList<String> sta = new LinkedList<String>();
223         int LEN = observe.size();
224         int M = stateSet.size();
225         double[][] Q = new double[LEN][];
226         int[][] Path = new int[LEN][];
227         Q[0] = new double[M];
228         Path[0] = new int[M];
229         for (int j = 0; j < M; j++) {
230             if (observeSet.contains(observe.get(0))) {// 观察值在训练样本中未出现过,则概率设为0
231                 Q[0][j] = stateProb[j]
232                         * emission[j][observeSet.indexOf(observe.get(0))];
233             } else {
234                 Q[0][j] = stateProb[j] * minEmission[j] / 2;
235                 System.err.println("观察值'" + observe.get(0) + "'在已标记样本中未出现过");
236             }
237             Path[0][j] = -1;
238         }
239         for (int i = 1; i < LEN; i++) {
240             Q[i] = new double[M];
241             Path[i] = new int[M];
242             for (int j = 0; j < M; j++) {
243                 double max = 0.0;
244                 int index = 0;
245                 for (int k = 0; k < M; k++) {
246                     if (Q[i - 1][k] * stateTrans[k][j] > max) {
247                         max = Q[i - 1][k] * stateTrans[k][j];
248                         index = k;
249                     }
250                 }
251                 if (observeSet.contains(observe.get(i))) {
252                     Q[i][j] = max
253                             * emission[j][observeSet.indexOf(observe.get(i))];
254                 } else {
255                     Q[i][j] = max * minEmission[j] / 2;
256                     System.err
257                             .println("观察值'" + observe.get(0) + "'在已标记样本中未出现过");
258                 }
259                 Path[i][j] = index;
260             }
261         }
262         // 找到最后一个时刻呈现哪种状态的概率最大
263         double max = 0;
264         int index = 0;
265         for (int i = 0; i < M; i++) {
266             if (Q[LEN - 1][i] > max) {
267                 max = Q[LEN - 1][i];
268                 index = i;
269             }
270         }
271         sta.add(stateSet.get(index));
272         // 动态规划,逆推回去各个时刻出现什么状态概率最大
273         for (int i = LEN - 1; i > 0; i--) {
274             index = Path[i][index];
275             sta.add(stateSet.get(index));
276         }
277         // 把状态序列再顺过来
278         Collections.reverse(sta);
279         return Pair.of(max, sta);
280     }
281 
282     public void baumWelch() {
283 
284     }
285 
286     public void initStateSet(String infile) {
287         FileUtil.readLines(infile, stateSet);
288     }
289 
290     public void initObserveSet(String infile) {
291         FileUtil.readLines(infile, observeSet);
292     }
293 
294     public void initStateProb(String infile) {
295         assert stateSet != null;
296         int stateCount = stateSet.size();
297         assert stateCount > 0;
298         stateProb = new double[stateCount];
299         List<String> lines = new LinkedList<String>();
300         FileUtil.readLines(infile, lines);
301         assert lines.size() >= stateCount;
302         for (int i = 0; i < stateCount; i++) {
303             stateProb[i] = Double.parseDouble(lines.get(i));
304         }
305     }
306 
307     public void initStateTrans(String infile) {
308         List<String> lines = new LinkedList<String>();
309         FileUtil.readLines(infile, lines);
310         stateTrans = new double[lines.size()][];
311         for (int i = 0; i < lines.size(); i++) {
312             String[] conts = lines.get(i).split("\\s+");
313             stateTrans[i] = new double[conts.length];
314             for (int j = 0; j < stateTrans[i].length; j++) {
315                 stateTrans[i][j] = Double.parseDouble(conts[j]);
316             }
317         }
318     }
319 
320     public void initConfusion(String infile) {
321         List<String> lines = new LinkedList<String>();
322         FileUtil.readLines(infile, lines);
323         emission = new double[lines.size()][];
324         minEmission = new double[lines.size()];
325         for (int i = 0; i < lines.size(); i++) {
326             String[] conts = lines.get(i).split("\\s+");
327             emission[i] = new double[conts.length];
328             double min = Double.MAX_VALUE;
329             for (int j = 0; j < emission[i].length; j++) {
330                 double ele = Double.parseDouble(conts[j]);
331                 emission[i][j] = ele;
332                 if (ele < min) {
333                     min = ele;
334                 }
335             }
336             minEmission[i] = min;
337         }
338     }
339 
340     public List<String> getStateSet() {
341         return stateSet;
342     }
343 
344     public void setStateSet(List<String> stateSet) {
345         this.stateSet = stateSet;
346     }
347 
348     public List<String> getObserveSet() {
349         return observeSet;
350     }
351 
352     public void setObserveSet(List<String> observeSet) {
353         this.observeSet = observeSet;
354     }
355 
356     public double[] getStateProb() {
357         return stateProb;
358     }
359 
360     public void setStateProb(double[] stateProb) {
361         this.stateProb = stateProb;
362     }
363 
364     public double[][] getStateTrans() {
365         return stateTrans;
366     }
367 
368     public void setStateTrans(double[][] stateTrans) {
369         this.stateTrans = stateTrans;
370     }
371 
372     public double[][] getEmission() {
373         return emission;
374     }
375 
376     public void setEmission(double[][] emission) {
377         this.emission = emission;
378         minEmission = new double[emission.length];
379         for (int i = 0; i < emission.length; i++) {
380             double min = Double.MAX_VALUE;
381             for (double ele : emission[i]) {
382                 if (ele < min) {
383                     min = ele;
384                 }
385             }
386             minEmission[i] = min;
387         }
388     }
389 
390     /**
391      * 通过一组计数计算概率
392      * 
393      * @param countArr
394      * @return
395      */
396     private double[] calProbByCount(double[] countArr) {
397         double sum = 0.0;
398         for (double count : countArr) {
399             sum += count;
400         }
401         double[] prob = new double[countArr.length];
402         for (int i = 0; i < countArr.length; i++) {
403             prob[i] = countArr[i] / sum;
404         }
405         return prob;
406     }
407 
408     /**
409      * 通过一组计数计算概率
410      * 
411      * @param countArr
412      * @return
413      */
414     private double[] calProbByCount(int[] countArr) {
415         double sum = 0.0;
416         for (double count : countArr) {
417             sum += count;
418         }
419         double[] prob = new double[countArr.length];
420         for (int i = 0; i < countArr.length; i++) {
421             prob[i] = countArr[i] / sum;
422         }
423         return prob;
424     }
425 }
View Code

 测试代码:

 1 import java.util.ArrayList;
 2 import java.util.LinkedList;
 3 import java.util.List;
 4 
 5 import org.junit.BeforeClass;
 6 import org.junit.Test;
 7 
 8 import xxzl.dm.core.sequence.HmmInference;
 9 import xxzl.dm.utility.Pair;
10 
11 public class TestHmmInference {
12 
13     private static HmmInference hmm = new HmmInference();
14 
15     @BeforeClass
16     public static void setup() {
17         hmm.initParam("/Users/zhangchaoyang/OneDrive/msr_train.corp");
18     }
19 
20     @Test
21     public void testForward() {
22         String sentence = "公布这些事实将对日记主人及其他被涉及者带来何种影响";
23         List<String> obs_seq = str2List(sentence);
24         double ratio = hmm.estimate(obs_seq);
25         System.out.println("观察序列出现的概率:" + ratio);
26     }
27 
28     @Test
29     public void testViterbi() {
30         String sentence = "公布这些事实将对日记主人及其他被涉及者带来何种影响";
31         List<String> obs_seq = str2List(sentence);
32         Pair<Double, LinkedList<String>> states = hmm.viterbi(obs_seq);
33         System.out.println("最可能的状态序列:" + states.second + ",其概率为:"
34                 + states.first);
35         System.out.println("分词结果为:" + wordSeg(sentence, states.second));
36     }
37 
38     private List<String> str2List(String sentence) {
39         List<String> obs_seq = new ArrayList<String>();
40         int i = 0;
41         while (i < sentence.length()) {
42             obs_seq.add(sentence.substring(i, i + 1));
43             i++;
44         }
45         return obs_seq;
46     }
47 
48     private String wordSeg(String sentence, List<String> tag) {
49         StringBuilder sb = new StringBuilder();
50         int len = sentence.length() <= tag.size() ? sentence.length() : tag
51                 .size();
52         for (int i = 0; i < len; i++) {
53             String word = sentence.substring(i, i + 1);
54             sb.append(word);
55             if (tag.get(i).equals("E") || tag.get(i).equals("S")) {
56                 sb.append("\t");
57             }
58         }
59         return sb.toString();
60     }
61 }
View Code

 

posted @ 2011-10-20 22:10  张朝阳  阅读(8557)  评论(3编辑  收藏  举报