代码改变世界

学习OpenCV——SVM 手写数字检测

2016-04-17 19:46  GarfieldEr007  阅读(378)  评论(0编辑  收藏  举报

转自http://blog.csdn.net/firefight/article/details/6452188

是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

 

 

使用OPENCV训练手写数字识别分类器 

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

 

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

 

    1. #include "stdafx.h"   
    2.   
    3. #include <fstream>   
    4. #include "opencv2/opencv.hpp"   
    5. #include <vector>   
    6.   
    7. using namespace std;  
    8. using namespace cv;  
    9.   
    10. #define SHOW_PROCESS 0   
    11. #define ON_STUDY 0   
    12.   
    13. class NumTrainData  
    14. {  
    15. public:  
    16.     NumTrainData()  
    17.     {  
    18.         memset(data, 0, sizeof(data));  
    19.         result = -1;  
    20.     }  
    21. public:  
    22.     float data[64];  
    23.     int result;  
    24. };  
    25.   
    26. vector<NumTrainData> buffer;  
    27. int featureLen = 64;  
    28.   
    29. void swapBuffer(char* buf)  
    30. {  
    31.     char temp;  
    32.     temp = *(buf);  
    33.     *buf = *(buf+3);  
    34.     *(buf+3) = temp;  
    35.   
    36.     temp = *(buf+1);  
    37.     *(buf+1) = *(buf+2);  
    38.     *(buf+2) = temp;  
    39. }  
    40.   
    41. void GetROI(Mat& src, Mat& dst)  
    42. {  
    43.     int left, right, top, bottom;  
    44.     left = src.cols;  
    45.     right = 0;  
    46.     top = src.rows;  
    47.     bottom = 0;  
    48.   
    49.     //Get valid area   
    50.     for(int i=0; i<src.rows; i++)  
    51.     {  
    52.         for(int j=0; j<src.cols; j++)  
    53.         {  
    54.             if(src.at<uchar>(i, j) > 0)  
    55.             {  
    56.                 if(j<left) left = j;  
    57.                 if(j>right) right = j;  
    58.                 if(i<top) top = i;  
    59.                 if(i>bottom) bottom = i;  
    60.             }  
    61.         }  
    62.     }  
    63.   
    64.     //Point center;   
    65.     //center.x = (left + right) / 2;   
    66.     //center.y = (top + bottom) / 2;   
    67.   
    68.     int width = right - left;  
    69.     int height = bottom - top;  
    70.     int len = (width < height) ? height : width;  
    71.   
    72.     //Create a squre   
    73.     dst = Mat::zeros(len, len, CV_8UC1);  
    74.   
    75.     //Copy valid data to squre center   
    76.     Rect dstRect((len - width)/2, (len - height)/2, width, height);  
    77.     Rect srcRect(left, top, width, height);  
    78.     Mat dstROI = dst(dstRect);  
    79.     Mat srcROI = src(srcRect);  
    80.     srcROI.copyTo(dstROI);  
    81. }  
    82.   
    83. int ReadTrainData(int maxCount)  
    84. {  
    85.     //Open image and label file   
    86.     const char fileName[] = "../res/train-images.idx3-ubyte";  
    87.     const char labelFileName[] = "../res/train-labels.idx1-ubyte";  
    88.   
    89.     ifstream lab_ifs(labelFileName, ios_base::binary);  
    90.     ifstream ifs(fileName, ios_base::binary);  
    91.   
    92.     if( ifs.fail() == true )  
    93.         return -1;  
    94.   
    95.     if( lab_ifs.fail() == true )  
    96.         return -1;  
    97.   
    98.     //Read train data number and image rows / cols   
    99.     char magicNum[4], ccount[4], crows[4], ccols[4];  
    100.     ifs.read(magicNum, sizeof(magicNum));  
    101.     ifs.read(ccount, sizeof(ccount));  
    102.     ifs.read(crows, sizeof(crows));  
    103.     ifs.read(ccols, sizeof(ccols));  
    104.   
    105.     int count, rows, cols;  
    106.     swapBuffer(ccount);  
    107.     swapBuffer(crows);  
    108.     swapBuffer(ccols);  
    109.   
    110.     memcpy(&count, ccount, sizeof(count));  
    111.     memcpy(&rows, crows, sizeof(rows));  
    112.     memcpy(&cols, ccols, sizeof(cols));  
    113.   
    114.     //Just skip label header   
    115.     lab_ifs.read(magicNum, sizeof(magicNum));  
    116.     lab_ifs.read(ccount, sizeof(ccount));  
    117.   
    118.     //Create source and show image matrix   
    119.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
    120.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
    121.     Mat img, dst;  
    122.   
    123.     char label = 0;  
    124.     Scalar templateColor(255, 0, 255 );  
    125.   
    126.     NumTrainData rtd;  
    127.   
    128.     //int loop = 1000;   
    129.     int total = 0;  
    130.   
    131.     while(!ifs.eof())  
    132.     {  
    133.         if(total >= count)  
    134.             break;  
    135.           
    136.         total++;  
    137.         cout << total << endl;  
    138.           
    139.         //Read label   
    140.         lab_ifs.read(&label, 1);  
    141.         label = label + '0';  
    142.   
    143.         //Read source data   
    144.         ifs.read((char*)src.data, rows * cols);  
    145.         GetROI(src, dst);  
    146.   
    147. #if(SHOW_PROCESS)   
    148.         //Too small to watch   
    149.         img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);  
    150.         resize(dst, img, img.size());  
    151.   
    152.         stringstream ss;  
    153.         ss << "Number " << label;  
    154.         string text = ss.str();  
    155.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
    156.   
    157.         //imshow("img", img);   
    158. #endif   
    159.   
    160.         rtd.result = label;  
    161.         resize(dst, temp, temp.size());  
    162.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
    163.   
    164.         for(int i = 0; i<8; i++)  
    165.         {  
    166.             for(int j = 0; j<8; j++)  
    167.             {  
    168.                     rtd.data[ i*8 + j] = temp.at<uchar>(i, j);  
    169.             }  
    170.         }  
    171.   
    172.         buffer.push_back(rtd);  
    173.   
    174.         //if(waitKey(0)==27) //ESC to quit   
    175.         //  break;   
    176.   
    177.         maxCount--;  
    178.           
    179.         if(maxCount == 0)  
    180.             break;  
    181.     }  
    182.   
    183.     ifs.close();  
    184.     lab_ifs.close();  
    185.   
    186.     return 0;  
    187. }  
    188.   
    189. void newRtStudy(vector<NumTrainData>& trainData)  
    190. {  
    191.     int testCount = trainData.size();  
    192.   
    193.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
    194.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
    195.   
    196.     for (int i= 0; i< testCount; i++)   
    197.     {   
    198.   
    199.         NumTrainData td = trainData.at(i);  
    200.         memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));  
    201.   
    202.         res.at<unsigned int>(i, 0) = td.result;  
    203.     }  
    204.   
    205.     /////////////START RT TRAINNING//////////////////   
    206.     CvRTrees forest;  
    207.     CvMat* var_importance = 0;  
    208.   
    209.     forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),  
    210.             CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));  
    211.     forest.save( "new_rtrees.xml" );  
    212. }  
    213.   
    214.   
    215. int newRtPredict()  
    216. {  
    217.     CvRTrees forest;  
    218.     forest.load( "new_rtrees.xml" );  
    219.   
    220.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
    221.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
    222.   
    223.     ifstream lab_ifs(labelFileName, ios_base::binary);  
    224.     ifstream ifs(fileName, ios_base::binary);  
    225.   
    226.     if( ifs.fail() == true )  
    227.         return -1;  
    228.   
    229.     if( lab_ifs.fail() == true )  
    230.         return -1;  
    231.   
    232.     char magicNum[4], ccount[4], crows[4], ccols[4];  
    233.     ifs.read(magicNum, sizeof(magicNum));  
    234.     ifs.read(ccount, sizeof(ccount));  
    235.     ifs.read(crows, sizeof(crows));  
    236.     ifs.read(ccols, sizeof(ccols));  
    237.   
    238.     int count, rows, cols;  
    239.     swapBuffer(ccount);  
    240.     swapBuffer(crows);  
    241.     swapBuffer(ccols);  
    242.   
    243.     memcpy(&count, ccount, sizeof(count));  
    244.     memcpy(&rows, crows, sizeof(rows));  
    245.     memcpy(&cols, ccols, sizeof(cols));  
    246.   
    247.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
    248.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
    249.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
    250.     Mat img, dst;  
    251.   
    252.     //Just skip label header   
    253.     lab_ifs.read(magicNum, sizeof(magicNum));  
    254.     lab_ifs.read(ccount, sizeof(ccount));  
    255.   
    256.     char label = 0;  
    257.     Scalar templateColor(255, 0, 0);  
    258.   
    259.     NumTrainData rtd;  
    260.   
    261.     int right = 0, error = 0, total = 0;  
    262.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
    263.     while(ifs.good())  
    264.     {  
    265.         //Read label   
    266.         lab_ifs.read(&label, 1);  
    267.         label = label + '0';  
    268.   
    269.         //Read data   
    270.         ifs.read((char*)src.data, rows * cols);  
    271.         GetROI(src, dst);  
    272.   
    273.         //Too small to watch   
    274.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
    275.         resize(dst, img, img.size());  
    276.   
    277.         rtd.result = label;  
    278.         resize(dst, temp, temp.size());  
    279.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
    280.         for(int i = 0; i<8; i++)  
    281.         {  
    282.             for(int j = 0; j<8; j++)  
    283.             {  
    284.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
    285.             }  
    286.         }  
    287.   
    288.         if(total >= count)  
    289.             break;  
    290.   
    291.         char ret = (char)forest.predict(m);   
    292.   
    293.         if(ret == label)  
    294.         {  
    295.             right++;  
    296.             if(total <= 5000)  
    297.                 right_1++;  
    298.             else  
    299.                 right_2++;  
    300.         }  
    301.         else  
    302.         {  
    303.             error++;  
    304.             if(total <= 5000)  
    305.                 error_1++;  
    306.             else  
    307.                 error_2++;  
    308.         }  
    309.   
    310.         total++;  
    311.   
    312. #if(SHOW_PROCESS)   
    313.         stringstream ss;  
    314.         ss << "Number " << label << ", predict " << ret;  
    315.         string text = ss.str();  
    316.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
    317.   
    318.         imshow("img", img);  
    319.         if(waitKey(0)==27) //ESC to quit   
    320.             break;  
    321. #endif   
    322.   
    323.     }  
    324.   
    325.     ifs.close();  
    326.     lab_ifs.close();  
    327.   
    328.     stringstream ss;  
    329.     ss << "Total " << total << ", right " << right <<", error " << error;  
    330.     string text = ss.str();  
    331.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
    332.     imshow("img", img);  
    333.     waitKey(0);  
    334.   
    335.     return 0;  
    336. }  
    337.   
    338. void newSvmStudy(vector<NumTrainData>& trainData)  
    339. {  
    340.     int testCount = trainData.size();  
    341.   
    342.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
    343.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
    344.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
    345.   
    346.     for (int i= 0; i< testCount; i++)   
    347.     {   
    348.   
    349.         NumTrainData td = trainData.at(i);  
    350.         memcpy(m.data, td.data, featureLen*sizeof(float));  
    351.         normalize(m, m);  
    352.         memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));  
    353.   
    354.         res.at<unsigned int>(i, 0) = td.result;  
    355.     }  
    356.   
    357.     /////////////START SVM TRAINNING//////////////////   
    358.     CvSVM svm = CvSVM();   
    359.     CvSVMParams param;   
    360.     CvTermCriteria criteria;  
    361.   
    362.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);   
    363.     param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
    364.   
    365.     svm.train(data, res, Mat(), Mat(), param);  
    366.     svm.save( "SVM_DATA.xml" );  
    367. }  
    368.   
    369.   
    370. int newSvmPredict()  
    371. {  
    372.     CvSVM svm = CvSVM();   
    373.     svm.load( "SVM_DATA.xml" );  
    374.   
    375.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
    376.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
    377.   
    378.     ifstream lab_ifs(labelFileName, ios_base::binary);  
    379.     ifstream ifs(fileName, ios_base::binary);  
    380.   
    381.     if( ifs.fail() == true )  
    382.         return -1;  
    383.   
    384.     if( lab_ifs.fail() == true )  
    385.         return -1;  
    386.   
    387.     char magicNum[4], ccount[4], crows[4], ccols[4];  
    388.     ifs.read(magicNum, sizeof(magicNum));  
    389.     ifs.read(ccount, sizeof(ccount));  
    390.     ifs.read(crows, sizeof(crows));  
    391.     ifs.read(ccols, sizeof(ccols));  
    392.   
    393.     int count, rows, cols;  
    394.     swapBuffer(ccount);  
    395.     swapBuffer(crows);  
    396.     swapBuffer(ccols);  
    397.   
    398.     memcpy(&count, ccount, sizeof(count));  
    399.     memcpy(&rows, crows, sizeof(rows));  
    400.     memcpy(&cols, ccols, sizeof(cols));  
    401.   
    402.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
    403.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
    404.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
    405.     Mat img, dst;  
    406.   
    407.     //Just skip label header   
    408.     lab_ifs.read(magicNum, sizeof(magicNum));  
    409.     lab_ifs.read(ccount, sizeof(ccount));  
    410.   
    411.     char label = 0;  
    412.     Scalar templateColor(255, 0, 0);  
    413.   
    414.     NumTrainData rtd;  
    415.   
    416.     int right = 0, error = 0, total = 0;  
    417.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
    418.     while(ifs.good())  
    419.     {  
    420.         //Read label   
    421.         lab_ifs.read(&label, 1);  
    422.         label = label + '0';  
    423.   
    424.         //Read data   
    425.         ifs.read((char*)src.data, rows * cols);  
    426.         GetROI(src, dst);  
    427.   
    428.         //Too small to watch   
    429.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
    430.         resize(dst, img, img.size());  
    431.   
    432.         rtd.result = label;  
    433.         resize(dst, temp, temp.size());  
    434.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
    435.         for(int i = 0; i<8; i++)  
    436.         {  
    437.             for(int j = 0; j<8; j++)  
    438.             {  
    439.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
    440.             }  
    441.         }  
    442.   
    443.         if(total >= count)  
    444.             break;  
    445.   
    446.         normalize(m, m);  
    447.         char ret = (char)svm.predict(m);   
    448.   
    449.         if(ret == label)  
    450.         {  
    451.             right++;  
    452.             if(total <= 5000)  
    453.                 right_1++;  
    454.             else  
    455.                 right_2++;  
    456.         }  
    457.         else  
    458.         {  
    459.             error++;  
    460.             if(total <= 5000)  
    461.                 error_1++;  
    462.             else  
    463.                 error_2++;  
    464.         }  
    465.   
    466.         total++;  
    467.   
    468. #if(SHOW_PROCESS)   
    469.         stringstream ss;  
    470.         ss << "Number " << label << ", predict " << ret;  
    471.         string text = ss.str();  
    472.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
    473.   
    474.         imshow("img", img);  
    475.         if(waitKey(0)==27) //ESC to quit   
    476.             break;  
    477. #endif   
    478.   
    479.     }  
    480.   
    481.     ifs.close();  
    482.     lab_ifs.close();  
    483.   
    484.     stringstream ss;  
    485.     ss << "Total " << total << ", right " << right <<", error " << error;  
    486.     string text = ss.str();  
    487.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
    488.     imshow("img", img);  
    489.     waitKey(0);  
    490.   
    491.     return 0;  
    492. }  
    493.   
    494. int main( int argc, char *argv[] )  
    495. {  
    496. #if(ON_STUDY)   
    497.     int maxCount = 60000;  
    498.     ReadTrainData(maxCount);  
    499.   
    500.     //newRtStudy(buffer);   
    501.     newSvmStudy(buffer);  
    502. #else   
    503.     //newRtPredict();   
    504.     newSvmPredict();  
    505. #endif   
    506.     return 0;  
    507. }
    508. //from: http://blog.csdn.net/yangtrees/article/details/7458466