svm的二分类opencv源代码

#include"highgui.h"
#include "cv.h"
#include "ml.h"
#include"cxcore.h"
#include<stdio.h>
#include<time.h>
#include<stdlib.h>
#pragma comment(lib,"cv210d.lib")
#pragma comment(lib,"highgui210d.lib")
#pragma comment(lib,"cxcore210d.lib")
#pragma comment(lib,"ml210d.lib")

///定义一个svm的类
CvSVM SVM ;

IplImage *train_image = cvCreateImage(cvSize(100,100),IPL_DEPTH_8U,3);
IplImage *result_image = cvCreateImage(cvSize(100,100),IPL_DEPTH_8U,3);
IplImage *vector_image = cvCreateImage(cvSize(100,100),IPL_DEPTH_8U,3);
IplImage *support_vector_image = cvCreateImage(cvSize(100,100),IPL_DEPTH_8U,3);

///构造SVM并训练函数
void SVMml()
{
 //产生训练数据,写入到train.txt中
 FILE *fp = fopen("train.txt","w");
     int i,a[5000],b[5000];
 int j;
 fprintf(fp,"%d\n",5000);//先写入数组长度
   srand(1);///固定训练样本
 for(i=0;i<5000;i++)
 { a[i] = rand()%100;
   b[i] = rand()%100;
   //printf("%d ",a[i]);
 
   if( 0<=b[i]&&b[i]<=50 )
   {  j=1;
      fprintf(fp,"%d %d %d\n",a[i],b[i],j);}
   if ( 50<b[i]&&b[i]<100)
    {  j=2;
   fprintf(fp,"%d %d %d\n",a[i],b[i],j);}
  
 }

 fclose(fp);//写完关闭
 
  int data[5000][3];
  int sample_count;
 
  ///读数据
  FILE *fp1 = fopen("train.txt","r");
  fscanf(fp1,"%d",&sample_count);
  printf("Find %d samples !\n",sample_count);
  for(int i=0;i<sample_count;i++)
  {
    fscanf(fp1,"%d %d %d",&data[i][0],&data[i][1],&data[i][2]);
  
  }
  fclose(fp1);
  
  CvMat *input_data,*output_data;
  input_data = cvCreateMat(5000,2,CV_32FC1);
  output_data = cvCreateMat(5000,1,CV_32SC1);
 
  ///给矩阵赋初值
  for(int i=0;i<sample_count;i++)
  {
   cvSetReal2D(input_data,i,0,(float)data[i][0]);
   cvSetReal2D(input_data,i,1,(float)data[i][1]);
   cvSetReal1D(output_data,i,(float)data[i][2]);
 
  }
  ///drawing train_image
   for(int i=0;i<sample_count;i++)
   {
    if(data[i][2]==1)
     cvSet2D(train_image,data[i][0],data[i][1],cvScalar(0,0,255));
    if(data[i][2]==2)
     cvSet2D(train_image,data[i][0],data[i][1],cvScalar(0,255,0));
   
        
      }
 
  CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS,1000,0.0001);
  CvSVMParams param = CvSVMParams(CvSVM::C_SVC,CvSVM::RBF,NULL,0.1, NULL, 3, NULL, NULL,NULL,criteria);
  ///训练
    /*degree – 内核函数(POLY)的参数degree。
 gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数 。
 coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。
 Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。
 nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 。
 p – SVM类型(EPS_SVR)的参数  。
 class_weights – C_SVC中的可选权重,赋给指定的类,乘以C今后变成  。
 所以这些权重影响不合类此外错误分类处罚项。权重越大,某一类此外误分类数据的处罚项就越大。
 term_crit – SVM的迭代练习过程的中断前提,解决项目组受束缚二次最优题目。
 您可以指定的公差和/或最大迭代次数。*/


 SVM.train(input_data,output_data,NULL,NULL,param);
 cvReleaseMat(&input_data);
 cvReleaseMat(&output_data);
 }

  ///输入测试数据预测输出结果函数,返回预测值
 int preout(int data1,int data2)
  {  
   

     CvMat *test_data = cvCreateMat(2,1,CV_32FC1);
  cvSetReal1D(test_data,0,(float)data1);
  cvSetReal1D(test_data,1,(float)data2);
  float c = SVM.predict(test_data);///注意这里:SVM.predict(test_data,true)结果会不同
  printf("%f ",c);
  int result = cvRound(SVM.predict(test_data));
  return result;
  cvReleaseMat(&test_data);


}
///主函数
 int main(int argc,char*argv[])
{  
 double ems = 0;
 int count = 0;
 cvZero(vector_image);
 cvZero(result_image);

 SVMml();

 int sv_num;
 const float *support_vector;
 sv_num = SVM.get_support_vector_count();
 ///支持向量的绘制
    printf("支持向量个数为:%d\n",sv_num);///支持向量的个数
 for(int i=0;i<sv_num;i++)
 {
   support_vector = SVM.get_support_vector(i);
      cvSet2D(support_vector_image,(int)(support_vector[0]),(int)(support_vector[1]),cvScalar(0,255,0));
   cvCircle(vector_image,cvPoint( (int)(support_vector[1]),(int)(support_vector[0]) ),
      1, cvScalar(255,255,255,0), -1,8,0);
 }
 // free(support_vector);
 //support_vector = NULL;
 
  int a[100],b[100];
  int result;
  int j;
  FILE *fp = fopen("test.txt","w");
  srand((unsigned )time(NULL));
    for(int i=0;i<100;i++)
 {
  
    a[i] = (rand()%100);
    b[i] = (rand()%100);
   
  if( 0<=b[i]&&b[i]<=50 )
  { 
    j=1;
    fprintf(fp,"%d %d %d\n",a[i],b[i],j);
  }
  if ( 50<b[i]&&b[i]<100)
  { 
    j=2;
    fprintf(fp,"%d %d %d\n",a[i],b[i],j);
  }
      ///预测
   result = preout(a[i],b[i]);

     if( result==1 )
    cvSet2D(result_image,a[i],b[i],cvScalar(0,0,255));
  ///红色代表分类1
     if( result==2 )
     cvSet2D(result_image,a[i],b[i],cvScalar(0,255,0));
   ///绿色代表分类2
  //判断误差
  if(result!=j)
   count++;
   printf("%d ",result);
    if( (i+1)%5==0)
    {
     printf("\n");
     fprintf(fp,"\n\n");
     
    }
 }
 fclose(fp);

 ems = count/100.0;
 printf("误差为:%f",ems);

 cvNamedWindow("训练数据图",0);
 cvNamedWindow("支持向量点图",0);
 cvNamedWindow("支持向量图",0);
 cvNamedWindow("输出结果图",0);
 cvShowImage("训练数据图",train_image);
 cvShowImage("支持向量图",vector_image);
 cvShowImage("支持向量点图",support_vector_image);
 cvShowImage("输出结果图",result_image);
   
    cvWaitKey(0);

   cvReleaseImage(&result_image);
   cvDestroyWindow("输出结果图");
   cvReleaseImage(&train_image);
   cvDestroyWindow("训练数据图");
   cvReleaseImage(&vector_image);
   cvDestroyWindow("支持向量图");
   cvReleaseImage(&support_vector_image);
   cvDestroyWindow("支持向量点图");

 return 0;
 }

posted on 2012-12-06 15:14  dutyhong  阅读(1306)  评论(0编辑  收藏  举报

导航