基于朴素贝叶斯的扑克牌花色识别
本程序只对扑克牌的花色进行训练和识别,扑克牌上数字的识别留待以后的学习中再完善。
本次只是简单地提取了扑克牌的 RGB 均值(3 个)、HSV 均值(3 个)、7 个 Hu 不变矩以及长宽比,共 14 个简单特征。其中,为了消除目标在图像中的位置、旋转等因素对长宽比的影响,长宽比是根据目标区域的最小外接矩形计算的。
部分图像如下图所示:
特征提取的部分代码如下所示:
[cpp] view plain copy
1.void CPokeAlgorithmDlg::CollectCharacter(IplImage* img, CvMat* mat, int rows)
2.{
3. if (img != nullptr)
4. {
5. showImage(img, IDC_PIC1); //显示图像
6.
7. IplImage* bitImage = nullptr, *grayImage = nullptr, *hsvImage = nullptr;
8.
9. bitImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);
10. grayImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 1);
11. hsvImage = cvCreateImage(cvGetSize(img), IPL_DEPTH_8U, 3);
12.
13.
14. cvCvtColor(img, hsvImage, CV_RGB2HSV);
15. cvCvtColor(img, grayImage, CV_RGB2GRAY);
16.
17. cvSmooth(grayImage, grayImage, CV_MEDIAN);
18. cvThreshold(grayImage, bitImage, 128, 255.0, CV_THRESH_BINARY);
19.
20. cvNot(bitImage, bitImage);
21.
22. IplConvKernel* element = cvCreateStructuringElementEx(5, 5, 2, 2, CV_SHAPE_ELLIPSE);
23. cvSmooth(bitImage, bitImage, CV_MEDIAN);
24. cvErode(bitImage, bitImage, element, 1);
25. cvDilate(bitImage, bitImage, element, 1);
26. cvReleaseStructuringElement(&element);
27. element = NULL;
28.
29.
30. CvMemStorage* storage = cvCreateMemStorage(0);
31. CvSeq* contour = 0;
32. cvFindContours(bitImage, storage, &contour, sizeof(CvContour), CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE); //轮廓检索
33.
34. for (; contour != 0; contour = contour->h_next)
35. {
36. double area = fabs(cvContourArea(contour, CV_WHOLE_SEQ));
37.
38. if (area > 2000) //此处阈值需重新调节
39. {
40. cvDrawContours(bitImage, contour, cvScalarAll(255), cvScalarAll(255), -1, CV_FILLED, 8);
41. CvRect rect = cvBoundingRect(contour, 0);
42.
43. CvBox2D minRect = cvMinAreaRect2(contour, storage);
44.
45. CvPoint2D32f rectPts[4] = { 0 };
46. cvBoxPoints(minRect, rectPts);
47. int nPts = 4; // 4 个顶点
48.
49. CvPoint minRectPts[4] = { 0 };
50. for (int i = 0; i < 4; ++i)
51. {
52. minRectPts[i] = cvPointFrom32f(rectPts[i]); //将 cvPoint2D32f 转化为 CvPoint
53. }
54. CvPoint *pt = minRectPts;
55.
56. //在图像中绘制矩形框
57. cvPolyLine(bitImage, &pt, &nPts, 1, 1, cvScalarAll(255), 1);
58.
59. int l1 = sqrtf((pt[0].x - pt[1].x)*(pt[0].x - pt[1].x) + (pt[0].y - pt[1].y)*(pt[0].y - pt[1].y));
60. int l2 = sqrtf((pt[2].x - pt[1].x)*(pt[2].x - pt[1].x) + (pt[2].y - pt[1].y)*(pt[2].y - pt[1].y));
61.
62. int length = l1 > l2 ? l1 : l2; //取较长边为图形的长
63. int width = l1 > l2 ? l2 : l1; //取较短边为图形的宽
64.
65. double r = (width * 1.0) / length; //长宽比
66.
67. cvSetReal2D(mat, rows, 0, r);
68.
69. double RMean = 0, GMean = 0, BMean = 0;
70. double HMean = 0, SMean = 0, VMean = 0;
71. int nCount = 0;
72.
73. for (int imgRow = rect.y; imgRow < rect.y + rect.height; ++imgRow)
74. {
75. for (int imgCol = rect.x; imgCol < rect.x + rect.width; ++imgCol)
76. {
77. CvScalar s = cvGet2D(bitImage, imgRow, imgCol);
78.
79. if (s.val[0] == 255)
80. {
81. s = cvGet2D(img, imgRow, imgCol);
82. RMean += s.val[2];
83. GMean += s.val[1];
84. BMean += s.val[0];
85.
86. s = cvGet2D(hsvImage, imgRow, imgCol);
87. HMean += s.val[0];
88. SMean += s.val[1];
89. VMean += s.val[2];
90.
91. ++nCount;
92. }
93. }
94. }// end RGB,HSV for
95.
96. RMean /= nCount;
97. GMean /= nCount;
98. BMean /= nCount;
99.
100. HMean /= nCount;
101. SMean /= nCount;
102. VMean /= nCount;
103.
104.
105. cvSetReal2D(mat, rows, 1, RMean);
106. cvSetReal2D(mat, rows, 2, GMean);
107. cvSetReal2D(mat, rows, 3, BMean);
108. cvSetReal2D(mat, rows, 4, HMean);
109. cvSetReal2D(mat, rows, 5, SMean);
110. cvSetReal2D(mat, rows, 6, VMean);
111.
112. //7个不变矩
113.
114. CvMoments moments;
115. cvMoments(contour, &moments, 1);
116. CvHuMoments huMoments;
117. cvGetHuMoments(&moments, &huMoments);
118.
119. double hu1 = huMoments.hu1;
120. double hu2 = huMoments.hu2;
121. double hu3 = huMoments.hu3;
122. double hu4 = huMoments.hu4;
123. double hu5 = huMoments.hu5;
124. double hu6 = huMoments.hu6;
125. double hu7 = huMoments.hu7;
126.
127. cvSetReal2D(mat, rows, 7, hu1);
128. cvSetReal2D(mat, rows, 8, hu2);
129. cvSetReal2D(mat, rows, 9, hu3);
130. cvSetReal2D(mat, rows, 10, hu4);
131. cvSetReal2D(mat, rows, 11, hu5);
132. cvSetReal2D(mat, rows, 12, hu6);
133. cvSetReal2D(mat, rows, 13, hu7);
134. }// end if
135. }
136.
137. showImage(hsvImage, IDC_PIC3);
138. showImage(bitImage, IDC_PIC2);
139.
140.
141. //释放内存
142. cvReleaseMemStorage(&storage);
143. storage = nullptr;
144. cvReleaseImage(&bitImage);
145. bitImage = nullptr;
146. cvReleaseImage(&grayImage);
147. grayImage = nullptr;
148. cvReleaseImage(&hsvImage);
149. hsvImage = nullptr;
150. }
151.
152. //释放内存
153. cvReleaseImage(&img);
154. img = nullptr;
155.}
朴素贝叶斯训练代码如下所示:
[cpp] view plain copy
1.Book* book = xlCreateXMLBookW();
2.
3. CvMat* dataMat = NULL;
4.
5. if (book->load(L"Data.xlsx"))
6. {
7. Sheet *sheet = book->getSheet(0);
8.
9. int myrow = sheet->lastRow();
10. int mycol = sheet->lastCol();
11.
12. if (sheet)
13. {
14. CvMat* importMat = cvCreateMat(myrow, mycol, CV_32FC1); //存储导入数据
15.
16. for (auto i = 0; i < myrow; ++i)
17. {
18. for (auto j = 0; j < mycol; j++)
19. {
20. double temp = sheet->readNum(i, j);
21. cvSetReal2D(importMat, i, j, temp);
22. }
23. }// end for
24.
25. dataMat = cvCloneMat(importMat);
26. }// end if
27. }
28.
29. book->release();
30.
31. MessageBox(L"数据导入完成");
32.
33. CvMat* lableMat = cvCreateMat(dataMat->rows, 1, CV_32FC1); //构建样本的分类标签
34. cvZero(lableMat);
35.
36. for (int i = 0; i < 4; ++i) //共分了 20 个不同的种类
37. {
38. for (int j = 0; j < 10; ++j) //每个品种共50个籽粒
39. {
40. cvSetReal2D(lableMat, i * 10 + j, 0, i + 1);
41. }
42. }
43.
44. CvNormalBayesClassifier nbc;
45. nbc.train(dataMat, lableMat);
46. nbc.save("bayes.txt");
47.
48. MessageBox(L"数据训练完成");
49.
50. CvMat* nbcResult = cvCreateMat(dataMat->rows, 1, CV_32FC1);
51. CvMat* nbcRow = NULL;
52.
53. for (int i = 0; i < dataMat->rows; ++i)
54. {
55. nbcRow = cvCreateMat(1, dataMat->cols, CV_32FC1);
56.
57. for (int j = 0; j < dataMat->cols; ++j)
58. {
59. float temp = cvGetReal2D(dataMat, i, j);
60. cvSetReal2D(nbcRow, 0, j, temp);
61. }
62.
63. unsigned int ret = 0;
64. ret = nbc.predict(nbcRow);
65. cvSetReal2D(nbcResult, i, 0, ret);
66. cvReleaseMat(&nbcRow);
67. nbcRow = NULL;
68. }
69.
70. int nCount = 0;
71.
72. for (int i = 0; i < 4; ++i)
73. {
74. for (int j = 0; j < 10; ++j)
75. {
76. int ret = cvGetReal2D(nbcResult, i * 10 + j, 0);
77. if (ret == (i + 1))
78. {
79. ++nCount;
80. }
81. }
82. }
83.
84. float recognize = 100 * nCount / 10 / 4;
85.
86. CString str;
87. str.Format(L"朴素贝叶斯 识别率为: %f", recognize);
88. str = str + L"%";
89. MessageBox(str);
90.
识别代码如下所示:
[cpp] view plain copy
1.CvNormalBayesClassifier nbc;
2. nbc.load("bayes.txt");
3.
4. CFileDialog dlg(TRUE, NULL, NULL, 0, L"图片文件(*.jpg)|*.jpg||");
5. if (dlg.DoModal() == IDOK)
6. {
7. USES_CONVERSION;
8. const char* loadPath = W2A(dlg.GetPathName());
9. IplImage* testImage = cvLoadImage(loadPath);
10.
11. CvMat* mat = cvCreateMat(1, 14, CV_32FC1);
12. CollectCharacter(testImage, mat, 0);
13.
14. int ret = nbc.predict(mat);
15. CString str;
16. switch (ret)
17. {
18. case 1:
19. str = L"黑桃";
20. break;
21. case 2:
22. str = "红桃";
23. break;
24. case 3:
25. str = "梅花";
26. break;
27. case 4:
28. str = "方块";
29. break;
30. }
31. AfxMessageBox(str);
32. cvReleaseMat(&mat);
33. mat = NULL;
34. }//end if

浙公网安备 33010602011771号