浅墨浓香

To reach the city by daybreak, you must travel through the night.


A Simple Handwritten Digit Recognition System (0-9)

Posted on 2017-02-25 16:11  浅墨浓香

1. Introduction to Machine Learning

(1) Machine learning aims to give computers a human-like ability to learn, so that they can solve problems better.

(2) The essence of machine learning

  ① Processing existing data (acquiring knowledge)

  ② Repeatedly processing large amounts of data (refining skills)

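  ③ Making decisions about unseen data (applying knowledge). For example, a shopping recommendation system can analyze purchase data to infer a customer's preferences and then recommend products tailored to them.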

(3) Machine learning and artificial intelligence: machine learning is an important branch of artificial intelligence.

 

2. Main Types of Machine Learning

(1) Supervised learning: learning from known (labeled) data under human guidance.

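(2) Unsupervised learning: learning from data without human intervention.

(3) Reinforcement learning:

  ① Humans predefine a reward rule that judges whether what the machine has learned is correct.

  ② Guided by the reward rule, the machine gradually approaches the target reward.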

3. Application Areas of Machine Learning

(1) Computer vision: face recognition, fingerprint recognition, autonomous driving, etc.

(2) Language processing: speech recognition, natural language processing, language understanding, etc.

(3) Big data analytics: weather forecasting, environmental monitoring, customer data mining, etc.

4. Machine Learning Case Study: the Simplest Handwriting Recognizer (Recognizing Handwritten Digits 0-9)

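(1) Feature extraction: find the bounding rectangle of the points in the stroke, divide it into a 3×3 grid, and use the percentage of points that fall into each cell as a 9-dimensional feature vector.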
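For readers without Qt at hand, here is a standard-C++ sketch of this feature-extraction step that follows the `makeFeature()` function in the program below: find the bounding rectangle of the stroke, split it into a 3×3 grid with integer cell sizes, and store each cell's share of the points as an integer percentage. The `Point` struct, the `Features` alias and the function names are illustrative assumptions, not part of the original program.

```cpp
#include <algorithm>
#include <array>
#include <vector>

// Hypothetical stand-in for QPoint so the sketch needs only the standard library.
struct Point { int x; int y; };

// 9-dimensional feature: percentage of stroke points in each cell of a 3x3 grid,
// cells numbered row by row (0..2 top row, 6..8 bottom row).
using Features = std::array<int, 9>;

// Maps a coordinate to a column/row index (0, 1 or 2) using the same boundaries
// as makeFeature(): [lo, lo+d), [lo+d, lo+2d), [lo+2d, max].
static int cellIndex(int v, int lo, int d) {
    if (v < lo + d) return 0;
    if (v < lo + 2 * d) return 1;
    return 2;
}

Features extractFeatures(const std::vector<Point>& points) {
    Features f{};                        // all zeros
    if (points.empty()) return f;

    // Bounding rectangle of the stroke.
    int minX = points[0].x, maxX = points[0].x;
    int minY = points[0].y, maxY = points[0].y;
    for (const Point& p : points) {
        minX = std::min(minX, p.x); maxX = std::max(maxX, p.x);
        minY = std::min(minY, p.y); maxY = std::max(maxY, p.y);
    }
    int dx = (maxX - minX) / 3;          // integer cell width, as in the original
    int dy = (maxY - minY) / 3;          // integer cell height

    // Count the points that fall into each cell.
    for (const Point& p : points) {
        int col = cellIndex(p.x, minX, dx);
        int row = cellIndex(p.y, minY, dy);
        f[row * 3 + col] += 1;
    }

    // Convert counts to integer percentages of all points.
    const int n = static_cast<int>(points.size());
    for (int& c : f) c = c * 100 / n;
    return f;
}
```

Because the cell width is `(maxX - minX) / 3` in integer arithmetic, the right column and bottom row absorb the rounding remainder, and a perfectly vertical stroke lands entirely in the right column.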

 

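(2) Storing feature vectors (learning): the feature vector of each labeled sample is saved in a feature library. In the program below, the library is an Access database with one table per digit, and it is loaded into memory at startup.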
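The learning step amounts to appending a labeled vector to a per-digit collection. Below is a minimal in-memory sketch of that idea in standard C++, assuming a hypothetical `FeatureStore` class in place of the Access tables the program uses (one table per digit, one row per learned sample).

```cpp
#include <array>
#include <cstddef>
#include <map>
#include <vector>

using Features = std::array<int, 9>;

// Hypothetical in-memory stand-in for the per-digit tables: each digit 0-9
// owns a growing list of learned feature vectors (one "row" per sample).
class FeatureStore {
public:
    // "Learning" simply remembers one more labeled sample.
    bool learn(int digit, const Features& f) {
        if (digit < 0 || digit > 9) return false;   // only 0-9 are valid labels
        samples_[digit].push_back(f);
        return true;
    }

    // All samples stored for a digit (empty if none have been learned yet).
    const std::vector<Features>& samples(int digit) const {
        static const std::vector<Features> kEmpty;
        auto it = samples_.find(digit);
        return it == samples_.end() ? kEmpty : it->second;
    }

    // Total number of samples across all digits.
    std::size_t totalSamples() const {
        std::size_t n = 0;
        for (const auto& entry : samples_) n += entry.second.size();
        return n;
    }

private:
    std::map<int, std::vector<Features>> samples_;
};
```

Keeping every sample, rather than averaging them into one template per digit, is what lets the recognizer compare an unknown stroke against each learned variant of a digit.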

 

(3) Recognition

  ① Extract the unknown digit's feature vector with the same method used during learning (step 1).

  ② Compute the spatial distance between the unknown digit's feature vector and each known digit's feature vectors.

  ③ Output the digit whose feature vector is closest.
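Steps ① to ③ describe nearest-neighbour classification. The standard-C++ sketch below uses the same distance as the program's `stdDev()`, a root-mean-square difference (Euclidean distance divided by √9), and returns the digit that owns the single closest stored sample. Taking the minimum per digit first and then the minimum across digits, as `recognize()` does, yields the same digit. The names `rmsDistance` and `recognize` and the `std::map` library layout are assumptions for illustration.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <map>
#include <vector>

using Features = std::array<int, 9>;

// Root-mean-square difference between two feature vectors
// (Euclidean distance divided by sqrt(9)).
double rmsDistance(const Features& a, const Features& b) {
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        double d = a[i] - b[i];
        sum += d * d;
    }
    return std::sqrt(sum / a.size());
}

// Nearest-neighbour classification: return the digit whose stored sample is
// closest to the unknown vector, or -1 if the library holds no samples.
int recognize(const Features& unknown,
              const std::map<int, std::vector<Features>>& library) {
    int best = -1;
    double bestDist = std::numeric_limits<double>::max();
    for (const auto& entry : library) {              // digits in ascending order
        for (const Features& sample : entry.second) {
            double d = rmsDistance(unknown, sample);
            if (d < bestDist) {                      // strict '<': smaller digit wins ties
                bestDist = d;
                best = entry.first;
            }
        }
    }
    return best;
}
```

Dividing by a constant does not change which sample is nearest, so plain squared Euclidean distance would rank the candidates identically and skip the square root.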

【Appendix】Accessing an Access Database from Qt via QODBC

(1) My Qt environment:

  ① Qt source code: D:\Qt\5.6\src

  ② Qt installation (MinGW 4.9, 32-bit kit): D:\Qt\5.6\mingw49_32

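(2) Build and install the ODBC driver plugin (run these commands from a command prompt whose PATH contains the qmake and mingw32-make of this kit):

  ① Step 1: cd D:\Qt\5.6\Src\qtbase\src\plugins\sqldrivers\odbc

  ② Step 2: qmake odbc.pro

  ③ Step 3: mingw32-make

  ④ Step 4: Copy the generated qsqlodbc.dll and qsqlodbcd.dll from D:\Qt\5.6\Src\qtbase\plugins\sqldrivers to D:\Qt\5.6\mingw49_32\plugins\sqldrivers.

  ⑤ Step 5: Copy the two .a libraries produced by the build into Qt's lib directory (D:\Qt\5.6\mingw49_32\lib).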

【Programming Lab】A Simple Handwritten Digit (0-9) Recognition System

//recognizer.pro

#-------------------------------------------------
#
# Project created by QtCreator 2017-02-23T20:48:25
#
#-------------------------------------------------

QT       += core gui sql
greaterThan(QT_MAJOR_VERSION, 4): QT += widgets

TARGET = Recognizer
TEMPLATE = app


SOURCES += main.cpp\
        Dialog.cpp \
    Database.cpp \
    Classes.cpp \
    AnalyseFeature.cpp

HEADERS  += Dialog.h \
    Database.h \
    Classes.h \
    AnalyseFeature.h

//Classes.h

#ifndef _CLASSES_H_
#define _CLASSES_H_
#include <QList>
#include <QString>

//Feature vector (9 dimensions)
class FeatureVector
{
    const int  nCount = 9;
    int* arr;
public:
    int  count();
    void init();
    void setData(int index, int value);
    int  getData(int index);
    //For testing
    QString toString();
    FeatureVector();
    FeatureVector(const FeatureVector& fv);
    FeatureVector& operator=(const FeatureVector& fv);
    ~FeatureVector();
};

//A digit (0-9)
class Digit
{
    int nValue;
    QList<FeatureVector> fvs; //the digit's group of feature vectors
public:
    int     count()const;
    FeatureVector& getAt(int index);
    void    append(FeatureVector& fv);
    void    addAt(int index, FeatureVector& fv);
    void    removeAt(int index);
    int     getValue();
    void    setValue(int value);
    Digit&  operator=(const Digit& digit);
    QString toString();
};

class FeatureLibrary
{
    QList<Digit> digits;
public:
    int    count();
    Digit& getAt(int index);
    void   addAt(int index, const Digit& digit);
    void   append(const Digit& digit);
    void   removeAt(int index);
    void   clear();
};

#endif // _CLASSES_H_

//Classes.cpp

#include "Classes.h"

//FeatureVector
FeatureVector::FeatureVector()
{
    arr = new int[nCount];
    init();
}

FeatureVector::FeatureVector(const FeatureVector &fv)
{
    arr = new int[nCount];
    for(int i=0; i<nCount; i++){
        arr[i] = fv.arr[i];
    }
}

FeatureVector& FeatureVector::operator=(const FeatureVector &fv)
{
    for(int i=0; i<nCount; i++){
        arr[i] = fv.arr[i];
    }
    return *this;
}

FeatureVector::~FeatureVector()
{
    delete[] arr; //arr was allocated with new[], so it must be freed with delete[]
}
void FeatureVector::init()
{
    for (int i=0; i<nCount; i++){
        arr[i] = 0;
    }
}

int FeatureVector::count()
{
    return nCount;
}

void FeatureVector::setData(int index, int value)
{
    if((0 <= index) && (index < nCount)){
        arr[index] = value;
    }
}

int FeatureVector::getData(int index)
{
    int ret = -1;
    if((0 <= index) && (index <nCount)){
        ret = arr[index];
    }
    return ret;
}

QString FeatureVector::toString()
{
    QString ret = "";
    QString tmp;
    for(int i=0; i<nCount; i++){
        tmp.setNum(arr[i]);
        if(i == (nCount-1)){
             ret += tmp;
        }else{
            ret += tmp + ",";
        }
    }
    return ret;
}

int Digit::count() const
{
    return fvs.count();
}

FeatureVector& Digit::getAt(int index)
{
    return fvs[index];
}

void Digit::append(FeatureVector &fv)
{
    fvs.append(fv);
}

void Digit::addAt(int index, FeatureVector &fv)
{
    fvs.insert(index, fv);
}

void Digit::removeAt(int index)
{
    fvs.removeAt(index);
}

int Digit::getValue()
{
    return nValue;
}

void Digit::setValue(int value)
{
    nValue = value;
}

Digit &Digit::operator=(const Digit &digit)
{
    nValue = digit.nValue; //copy the digit's value as well as its feature vectors
    fvs = digit.fvs;

    return *this;
}

//Debug helper: returns all feature vectors of this digit as text
QString Digit::toString()
{
    QString ret = "";
    QString tmp;
    tmp.setNum(nValue);
    ret += tmp+": ";

    for(int i=0; i<fvs.count(); i++){
        if(i == fvs.count() -1){
            ret += fvs[i].toString();
        }else{
            ret += fvs[i].toString() +" ";
        }
    }
    return ret;
}

int FeatureLibrary::count()
{
    return digits.count();
}

Digit& FeatureLibrary::getAt(int index)
{
    return digits[index];
}

void FeatureLibrary::addAt(int index, const Digit &digit)
{
    Digit& dg = getAt(index);
    dg = digit;
    //digits.insert(index, digit);
}

void FeatureLibrary::append(const Digit &digit)
{
    digits.append(digit);
}

void FeatureLibrary::removeAt(int index)
{
    digits.removeAt(index);
}

void FeatureLibrary::clear()
{
   digits.clear();
}
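`FeatureVector` manages a raw `new int[]` buffer, so it needs a hand-written copy constructor, assignment operator and destructor, and the destructor must use `delete[]`. A common alternative is to hold the nine values in a `std::array`, which the compiler copies, assigns and destroys correctly on its own (the "rule of zero"). The class below is a hypothetical, Qt-free sketch of that design, not the author's code; `toString()` returns `std::string` instead of `QString`.

```cpp
#include <array>
#include <string>

// Hypothetical alternative to FeatureVector: a fixed-size std::array member
// gives correct copying, assignment and destruction for free (rule of zero),
// so no hand-written copy constructor, operator= or destructor is needed.
class FixedFeatureVector {
public:
    static constexpr int kCount = 9;

    int count() const { return kCount; }

    void init() { data_.fill(0); }

    void setData(int index, int value) {
        if (0 <= index && index < kCount) data_[index] = value;   // ignore out-of-range
    }

    int getData(int index) const {
        return (0 <= index && index < kCount) ? data_[index] : -1; // -1 as in the original
    }

    // Comma-separated values, e.g. "1,0,0,0,0,0,0,0,0".
    std::string toString() const {
        std::string s;
        for (int i = 0; i < kCount; ++i) {
            if (i > 0) s += ",";
            s += std::to_string(data_[i]);
        }
        return s;
    }

private:
    std::array<int, kCount> data_{};   // zero-initialised
};
```

Copies are now independent values, so storing them in a `QList` or `std::vector` needs no special care.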

//Database.h

#ifndef _DATABASE_H_
#define _DATABASE_H_
#include <QSqlDatabase>
#include "Classes.h"

class Database
{
private:
    QSqlDatabase db;
public:
    Database();
    void loadFeatureLibrary(FeatureLibrary& fl);
    bool addData(int digit, FeatureVector& fv);
    ~Database();
};

#endif // _DATABASE_H_

//Database.cpp

#include "Database.h"
#include "Classes.h"
#include <QSqlRecord>
#include <QSqlQuery>
#include <QDebug>

Database::Database()
{
    db = QSqlDatabase::addDatabase("QODBC");
    db.setDatabaseName("DRIVER={Microsoft Access Driver (*.mdb)};FIL={MS Access};DBQ=database.mdb");
}

void Database::loadFeatureLibrary(FeatureLibrary& fl)
{
    if(db.open())
    {
        QSqlQuery query(db);
        fl.clear();
        QString sql;

        for (int i = 0; i<=9; i++){
            //The table names are digits, so bracket them in Access SQL
            sql = QString("select * from [%1]").arg(i);
            query.prepare(sql);
            query.exec();

            Digit digit;
            digit.setValue(i);
            FeatureVector fv;

            while(query.next()){               
                int count = qMin(fv.count(), query.record().count());
                //Read the feature vector
                fv.init();
                for(int j=0; j<count; j++){
                    fv.setData(j, query.value(j).toInt());
                }
                digit.append(fv);
            }
            fl.append(digit);
        }
        db.close();
    }
}

bool Database::addData(int digit, FeatureVector& fv)
{
    bool ret = false;
    if(db.open())
    {
        //Table and column names are digits, so bracket them in Access SQL
        QString sql = QString("insert into [%1] ([0],[1],[2],[3],[4],[5],[6],[7],[8]) ").arg(digit);
        sql += "values(:arg0,:arg1,:arg2,:arg3,:arg4,:arg5,:arg6,:arg7,:arg8)";

        QSqlQuery query(db);
        query.prepare(sql);
        for(int i=0; i<fv.count(); i++){
            QString arg;
            arg.sprintf(":arg%d", i);
            query.bindValue(arg, fv.getData(i));
        }

        ret = query.exec();

        db.close();
    }
    return ret;
}

Database::~Database()
{
    if (db.isOpen())
       db.close();
}
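The table names `0` to `9` and column names `0` to `8` are not ordinary identifiers, and Access SQL generally expects such names to be enclosed in square brackets. As a Qt-free illustration, the hypothetical helper `buildInsertSql` below builds the same kind of parameterised `insert` statement that `addData()` prepares, with bracketed names and the `:arg0` to `:arg8` placeholders it binds.

```cpp
#include <string>

// Hypothetical helper: builds the parameterised INSERT used to store one
// feature vector (count values) in the table named after its digit.
std::string buildInsertSql(int digit, int count = 9) {
    std::string cols, args;
    for (int i = 0; i < count; ++i) {
        if (i > 0) { cols += ","; args += ","; }
        cols += "[" + std::to_string(i) + "]";   // bracketed column name
        args += ":arg" + std::to_string(i);      // named placeholder for bindValue()
    }
    return "insert into [" + std::to_string(digit) + "] (" + cols + ") values(" + args + ")";
}
```

Generating the column list and placeholders from one loop keeps them in step if the feature vector ever changes size.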

//AnalyseFeature.h

#ifndef _ANALYSEFEATURE_H_
#define _ANALYSEFEATURE_H_
#include <QList>
#include <QPoint>
#include "Classes.h"

class AnalyseFeature
{
private:
    //Standard Deviation
    static double minStdDev(FeatureVector& fv, Digit& digit);    //smallest distance between fv and any vector in the digit's group
    static double stdDev(FeatureVector &fv1, FeatureVector fv2); //root-mean-square difference between two vectors
public:
    //Compute the feature vector
    static void makeFeature(const QList<QPoint>& points, FeatureVector& fv);

    //Recognize the drawn shape
    static int recognize(FeatureVector& fv, FeatureLibrary& fl);
};

#endif // _ANALYSEFEATURE_H_

//AnalyseFeature.cpp

#include "AnalyseFeature.h"
//#include <float.h>  //for DBL_MAX macro
#include <climits>
#include <qmath.h>

//Generate a feature vector from the given points
void AnalyseFeature::makeFeature(const QList<QPoint>& points, FeatureVector& fv)
{
    int count = points.count();
    int maxX = 0;
    int maxY = 0;
    int minX = 0;
    int minY = 0;
    int x1 = 0;
    int y1 = 0;
    int x2 = 0;
    int y2 = 0;
    int dx = 0;
    int dy = 0;

    if(count <= 0)
        return;

    QList<QPoint>::const_iterator iter = points.begin();
    maxX = (*iter).x();
    maxY = (*iter).y();
    minX = (*iter).x();
    minY = (*iter).y();

    //Find the bounding rectangle of the shape
    while(iter < points.end()){
        if(maxX < (*iter).x()){
            maxX = (*iter).x();
        }
        if(maxY < (*iter).y()){
            maxY = (*iter).y();
        }
        if(minX > (*iter).x()){
            minX = (*iter).x();
        }
        if(minY > (*iter).y()){
            minY = (*iter).y();
        }
        ++iter;
    }

    //Split the bounding rectangle into a 3x3 grid
    dx = (maxX - minX) / 3;
    dy = (maxY - minY) / 3;
    x1 = minX + dx;
    y1 = minY + dy;
    x2 = x1 + dx;
    y2 = y1 + dy;

    //Count the points in each cell. Columns are [minX,x1), [x1,x2), [x2,maxX];
    //rows are [minY,y1), [y1,y2), [y2,maxY]; cells are numbered row by row (0-8)
    int array[9] = {0};
    for(iter = points.begin(); iter != points.end(); ++iter)
    {
        const QPoint& pos = *iter;
        int col = (pos.x() < x1) ? 0 : ((pos.x() < x2) ? 1 : 2);
        int row = (pos.y() < y1) ? 0 : ((pos.y() < y2) ? 1 : 2);
        array[row * 3 + col] += 1;
    }

    //Convert each cell's point count into a percentage of all points
    for(int i=0; i<9; i++){
        array[i] *= 100;
        array[i] /=count;
        fv.setData(i, array[i]);
    }
}

//minStdDev
double AnalyseFeature::minStdDev(FeatureVector& fv, Digit &digit)
{
    double ret = INT_MAX;
    double temp = 0;

    for(int i=0; i<digit.count(); i++)
    {
        temp = stdDev(fv, digit.getAt(i));
        if(temp < ret){
              ret = temp;
        }
    }
    return ret;
}

//Root-mean-square difference between two feature vectors (Euclidean distance divided by sqrt(count))
double AnalyseFeature::stdDev(FeatureVector &fv1, FeatureVector fv2)
{
    double ret = 0;
    int count = qMin(fv1.count(), fv2.count());

    if (count <= 0)
        return ret;

    for(int i=0; i<count; i++){
        ret += (fv1.getData(i) - fv2.getData(i))*(fv1.getData(i) - fv2.getData(i));
    }

    ret /= count;
    ret = qSqrt(ret);

    return ret;
}

//Recognition (fv: the vector to analyse; fl: the feature-vector library)
int AnalyseFeature::recognize(FeatureVector& fv, FeatureLibrary& fl)
{
    int ret = -1;
    double minDist = INT_MAX;
    double temp = 0;
    for(int i = 0; i<fl.count(); i++){
        Digit& digit = fl.getAt(i);

        temp = minStdDev(fv, digit); //compare fv with this digit's group of feature vectors
        if(temp < minDist){
            minDist = temp;
            ret = digit.getValue();
        }
    }

    return ret;
}

//Dialog.h

#ifndef _DIALOG_H_
#define _DIALOG_H_

#include <QDialog>
#include <QPainter>
#include <QLabel>
#include <QPushButton>
#include <QComboBox>
#include <QList>
#include "Classes.h"
#include "Database.h"
#include "AnalyseFeature.h"

class Dialog : public QDialog
{
    Q_OBJECT

private:
    QLabel noticeLbl;
    QLabel inputLbl;
    QLabel recognizeLbl;
    QPushButton studyBtn;
    QPushButton recognizeBtn;
    QPushButton cleanBtn;
    QComboBox   objective;

    FeatureLibrary featureLibrary;
    Database db;
    QList<QPoint> points;

    //Draw a shape
    void draw(QPainter& painter, QList<QPoint>& points);
    bool isDown;

protected slots:
    void onStudyClick();
    void onCleanClick();
    void onRecognizeClick();
//protected:
    //void mousePressEvent(QMouseEvent* evt);
    //void mouseMoveEvent(QMouseEvent* evt);
    //void mouseReleaseEvent(QMouseEvent* evt);
    //void paintEvent(QPaintEvent *);
public:
    Dialog(QWidget *parent = 0);
    bool eventFilter(QObject* obj, QEvent* e);
    ~Dialog();
};

#endif // _DIALOG_H_

//Dialog.cpp

#include "Dialog.h"
#include <QStringList>
#include <QMessageBox>
#include <QPainter>
#include <QMouseEvent>
#include <QFont>
#include "Database.h"
#include <QDebug>

Dialog::Dialog(QWidget *parent): QDialog(parent, Qt::WindowMinimizeButtonHint | Qt::WindowCloseButtonHint),
    noticeLbl(this),inputLbl(this),recognizeLbl(this),
    studyBtn(this),recognizeBtn(this),cleanBtn(this),
    objective(this),isDown(false)
{
    setWindowTitle("0-9 Handwritten Digit Recognition System [浅墨浓香:25590009]");
    setFixedSize(413, 280);

    //Recognition result
    recognizeLbl.move(31, 62);
    recognizeLbl.setFixedSize(112, 144);
    recognizeLbl.setFrameStyle(QFrame::StyledPanel | QFrame::Raised);
    QFont ft;
    ft.setPointSize(80);
    recognizeLbl.setFont(ft);
    recognizeLbl.setAlignment(Qt::AlignHCenter | Qt::AlignVCenter);


    //Handwriting area
    inputLbl.move(144, 38);
    inputLbl.setFixedSize(264, 204);
    inputLbl.installEventFilter(this);

    //noticeLbl
    noticeLbl.move(48, 5);
    noticeLbl.setText("Write one digit (0-9) with the mouse in the blank area to recognize it!");
    noticeLbl.setFixedSize(317,42);

    //studyBtn
    studyBtn.move(12, 243);
    studyBtn.setFixedSize(72, 24);
    studyBtn.setText("Learn");

    //recognizeBtn
    recognizeBtn.move(238, 243);
    recognizeBtn.setText("Recognize");
    recognizeBtn.setFixedSize(72, 24);

    //cleanBtn
    cleanBtn.move(326, 243);
    cleanBtn.setFixedSize(72, 24);
    cleanBtn.setText("Clear");

    //objective
    objective.move(100, 246);
    for(int i = 0; i<=9; i++){
        objective.addItem(QString::number(i));
    }

    objective.setCurrentIndex(0);

    db.loadFeatureLibrary(featureLibrary);

    //Connect signals and slots
    QObject::connect(&studyBtn, SIGNAL(clicked()), this, SLOT(onStudyClick()));
    QObject::connect(&cleanBtn, SIGNAL(clicked()), this, SLOT(onCleanClick()));
    QObject::connect(&recognizeBtn, SIGNAL(clicked()), this, SLOT(onRecognizeClick()));
}

void Dialog::draw(QPainter &painter, QList<QPoint> &points)
{
    for(int i=0; i<points.count() -1; i++){
        painter.drawLine(points[i],points[i+1]);
    }
}

void Dialog::onStudyClick()
{
    //Check against the library size: it is empty if the database failed to load
    if((0 <= objective.currentIndex()) && (objective.currentIndex() < featureLibrary.count())
             &&(points.count() > 0)){

        FeatureVector fv;

        AnalyseFeature::makeFeature(points, fv);

        Digit& digit = featureLibrary.getAt(objective.currentIndex());

        if(QMessageBox::Ok == QMessageBox::question(this, "Learn Digit", "Are you sure you want to learn this digit?",
                                                    QMessageBox::Ok | QMessageBox::No)){
            bool bOk = db.addData(objective.currentIndex(), fv);
            if (bOk) {
                digit.append(fv);
            }
        }
    }
}

void Dialog::onCleanClick()
{
    points.clear();
    recognizeLbl.setText("");
    inputLbl.update();
}

void Dialog::onRecognizeClick()
{
    if(points.count() <= 0)
        return;

    FeatureVector fv;
    AnalyseFeature::makeFeature(points, fv);

    int num = AnalyseFeature::recognize(fv, featureLibrary);
    if(num != -1)
        recognizeLbl.setNum(num);
    else
        QMessageBox::information(this, "Notice", "Unrecognized digit; please teach the system more samples.",QMessageBox::Ok);
}

bool Dialog::eventFilter(QObject *obj, QEvent *e)
{
    bool ret = true; //Event handled here; do not pass it on
    if(obj == &inputLbl){
        //Mouse button pressed
        if(e->type() == QEvent::MouseButtonPress)
        {
            QMouseEvent* evt = dynamic_cast<QMouseEvent*>(e);
            isDown = true;
            points.append(evt->pos());
        //Mouse moved
        }else if(e->type() == QEvent::MouseMove){
            QMouseEvent* evt = dynamic_cast<QMouseEvent*>(e);
            if(isDown){
                points.append(evt->pos());
                inputLbl.update();
            }
         //Mouse button released
        }else if(e->type() == QEvent::MouseButtonRelease){
            isDown = false;
        //Paint event
        }else if(e->type() == QEvent::Paint){
            QPainter painter(&inputLbl);
            draw(painter, points);
        }else{
            ret = false;
        }
    }else{
        ret = QDialog::eventFilter(obj, e);
    }

    return ret;
}

//void Dialog::mousePressEvent(QMouseEvent *evt)
//{
//    isDown = true;
//    points.append(evt->pos());
//}

//void Dialog::mouseMoveEvent(QMouseEvent *evt)
//{
//    if(isDown){
//        points.append(evt->pos());
//        update();
//    }
//}

//void Dialog::mouseReleaseEvent(QMouseEvent *evt)
//{
//    Q_UNUSED(evt);
//    isDown = false;
//}

//void Dialog::paintEvent(QPaintEvent*)
//{
//    QPainter painter(this);
//    draw(painter, points);
//}

Dialog::~Dialog()
{

}

//main.cpp

#include "Dialog.h"
#include <QApplication>

int main(int argc, char *argv[])
{
    QApplication a(argc, argv);
    Dialog w;
    w.show();

    return a.exec();
}