关灯

学习SVM(一) SVM模型训练与分类的OpenCV实现

[复制链接]
admin 发表于 2019-1-20 13:13:52 | 显示全部楼层 |阅读模式 打印 上一主题 下一主题
 

Andrew Ng 在斯坦福大学的机器学习公开课上这样评价支持向量机: 
support vector machines is the supervised learning algorithm that many people consider the most effective off-the-shelf supervised learning algorithm.That point of view is debatable,but there are many people that hold that point of view.

可见,在监督学习算法中支持向量机有着非常广泛的应用,而且在解决图像分类问题时有着优异的效果。

OpenCV集成了这种学习算法,它被包含在ml模块下的CvSVM类中,下面我们用OpenCV实现SVM的数据准备模型训练加载模型实现分类,为了理解起来更加直观,我们用三个工程来实现。

  数据准备

在OpenCV的安装路径下,搜索digits,可以得到一张图片,图片大小为1000*2000,有0-9的10个数字,每5行为一个数字,总共50行,共有5000个手写数字,每个数字块大小为20*20。 下面将把这些数字中的0和1作为二分类的准备数据。其中0有500张,1有500张。 
用下面的代码将图片准备好,在写入路径提前建立好文件夹:

  1. #include <opencv2/opencv.hpp>
  2. #include <iostream>
  3. using namespace std;
  4. using namespace cv;
  5. int main()
  6. {
  7. char ad[128]={0};
  8. int filename = 0,filenum=0;
  9. Mat img = imread("digits.png");
  10. Mat gray;
  11. cvtColor(img, gray, CV_BGR2GRAY);
  12. int b = 20;
  13. int m = gray.rows / b; //原图为1000*2000
  14. int n = gray.cols / b; //裁剪为5000个20*20的小图块
  15. for (int i = 0; i < m; i++)
  16. {
  17. int offsetRow = i*b; //行上的偏移量
  18. if(i%5==0&&i!=0)
  19. {
  20. filename++;
  21. filenum=0;
  22. }
  23. for (int j = 0; j < n; j++)
  24. {
  25. int offsetCol = j*b; //列上的偏移量
  26. sprintf_s(ad, "D:\\data\\%d\\%d.jpg",filename,filenum++);
  27. //截取20*20的小块
  28. Mat tmp;
  29. gray(Range(offsetRow, offsetRow + b), Range(offsetCol, offsetCol + b)).copyTo(tmp);
  30. imwrite(ad,tmp);
  31. }
  32. }
  33. return 0;
  34. }
复制代码

 

最后可以得到这样的结果: 

1532925645478175979.jpg


组织的二分类数据形式为:

  1. --D:
  2. --data
  3. --train_image
  4. --0(400张)
  5. --1(400张)
  6. --test_image
  7. --0(100张)
  8. --1(100张)
复制代码

 

1532925683950422480.jpg

 
训练数据800张,0,1各400张;测试数据200张,0,1各100张

  模型训练

数据准备完成之后,就可以用下面的代码训练了:

  1. #include <stdio.h>
  2. #include <time.h>
  3. #include <opencv2/opencv.hpp>
  4. #include <opencv/cv.h>
  5. #include <iostream>
  6. #include <opencv2/core/core.hpp>
  7. #include <opencv2/highgui/highgui.hpp>
  8. #include <opencv2/ml/ml.hpp>
  9. #include <io.h>
  10. using namespace std;
  11. using namespace cv;
  12. void getFiles( string path, vector<string>& files);
  13. void get_1(Mat& trainingImages, vector<int>& trainingLabels);
  14. void get_0(Mat& trainingImages, vector<int>& trainingLabels);
  15. int main()
  16. {
  17. //获取训练数据
  18. Mat classes;
  19. Mat trainingData;
  20. Mat trainingImages;
  21. vector<int> trainingLabels;
  22. get_1(trainingImages, trainingLabels);
  23. get_0(trainingImages, trainingLabels);
  24. Mat(trainingImages).copyTo(trainingData);
  25. trainingData.convertTo(trainingData, CV_32FC1);
  26. Mat(trainingLabels).copyTo(classes);
  27. //配置SVM训练器参数
  28. CvSVMParams SVM_params;
  29. SVM_params.svm_type = CvSVM::C_SVC;
  30. SVM_params.kernel_type = CvSVM::LINEAR;
  31. SVM_params.degree = 0;
  32. SVM_params.gamma = 1;
  33. SVM_params.coef0 = 0;
  34. SVM_params.C = 1;
  35. SVM_params.nu = 0;
  36. SVM_params.p = 0;
  37. SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
  38. //训练
  39. CvSVM svm;
  40. svm.train(trainingData, classes, Mat(), Mat(), SVM_params);
  41. //保存模型
  42. svm.save("svm.xml");
  43. cout<<"训练好了!!!"<<endl;
  44. getchar();
  45. return 0;
  46. }
  47. void getFiles( string path, vector<string>& files )
  48. {
  49. long hFile = 0;
  50. struct _finddata_t fileinfo;
  51. string p;
  52. if((hFile = _findfirst(p.assign(path).append("\\*").c_str(),&fileinfo)) != -1)
  53. {
  54. do
  55. {
  56. if((fileinfo.attrib & _A_SUBDIR))
  57. {
  58. if(strcmp(fileinfo.name,".") != 0 && strcmp(fileinfo.name,"..") != 0)
  59. getFiles( p.assign(path).append("\").append(fileinfo.name), files );
  60. }
  61. else
  62. {
  63. files.push_back(p.assign(path).append("\").append(fileinfo.name) );
  64. }
  65. }while(_findnext(hFile, &fileinfo) == 0);
  66. _findclose(hFile);
  67. }
  68. }
  69. void get_1(Mat& trainingImages, vector<int>& trainingLabels)
  70. {
  71. char * filePath = "D:\\data\\train_image\\1";
  72. vector<string> files;
  73. getFiles(filePath, files );
  74. int number = files.size();
  75. for (int i = 0;i < number;i++)
  76. {
  77. Mat SrcImage=imread(files[i].c_str());
  78. SrcImage= SrcImage.reshape(1, 1);
  79. trainingImages.push_back(SrcImage);
  80. trainingLabels.push_back(1);
  81. }
  82. }
  83. void get_0(Mat& trainingImages, vector<int>& trainingLabels)
  84. {
  85. char * filePath = "D:\\data\\train_image\\0";
  86. vector<string> files;
  87. getFiles(filePath, files );
  88. int number = files.size();
  89. for (int i = 0;i < number;i++)
  90. {
  91. Mat SrcImage=imread(files[i].c_str());
  92. SrcImage= SrcImage.reshape(1, 1);
  93. trainingImages.push_back(SrcImage);
  94. trainingLabels.push_back(0);
  95. }
  96. }
复制代码

 

整个训练过程可以分为一下几个部分: 
数据准备: 
该例程中一个定义了三个子程序用来实现数据准备工作: 

  • getFiles()用来遍历文件夹下所有文件。
  • getBubble()用来获取有气泡的图片和与其对应的Labels,该例程将Labels定为1。 
  • getNoBubble()用来获取没有气泡的图片与其对应的Labels,该例程将Labels定为0。 
  • getBubble()与getNoBubble()将获取一张图片后会将图片(特征)写入到容器中,紧接着会将标签写入另一个容器中,这样就保证了特征和标签是一一对应的关系push_back(0)或者push_back(1)其实就是我们贴标签的过程。
  1. trainingImages.push_back(SrcImage);
  2. trainingLabels.push_back(0);
复制代码

 

在主函数中,将getBubble()与getNoBubble()写好的包含特征的矩阵拷贝给trainingData,将包含标签的vector容器进行类型转换后拷贝到trainingLabels里,至此,数据准备工作完成,trainingData与trainingLabels就是我们要训练的数据。

  1. Mat classes;
  2. Mat trainingData;
  3. Mat trainingImages;
  4. vector<int> trainingLabels;
  5. getBubble(trainingImages, trainingLabels);
  6. getNoBubble(trainingImages, trainingLabels);
  7. Mat(trainingImages).copyTo(trainingData);
  8. trainingData.convertTo(trainingData, CV_32FC1);
  9. Mat(trainingLabels).copyTo(classes);
复制代码

 

特征选取 
其实特征提取和数据的准备是同步完成的,我们最后要训练的也是正负样本的特征。本例程中同样在getBubble()与getNoBubble()函数中完成特征提取工作,只是我们简单粗暴将整个图的所有像素作为了特征,因为我们关注更多的是整个的训练过程,所以选择了最简单的方式完成特征提取工作,除此中外,特征提取的方式有很多,比如LBP,HOG等等。

  1. SrcImage= SrcImage.reshape(1, 1);
复制代码

 

我们利用reshape()函数完成特征提取,原型如下:

  1. Mat reshape(int cn, int rows=0) const;
复制代码

 

可以看到该函数的参数非常简单,cn为新的通道数,如果cn = 0,表示通道数不会改变。参数rows为新的行数,如果rows = 0,表示行数不会改变。我们将参数定义为reshape(1, 1)的结果就是原图像对应的矩阵将被拉伸成一个一行的向量,作为特征向量。 
 

参数配置 
参数配置是SVM的核心部分,在Opencv中它被定义成一个结构体类型,如下:

  1. struct CV_EXPORTS_W_MAP CvSVMParams
  2. {
  3. CvSVMParams();
  4. CvSVMParams(
  5. int svm_type,
  6. int kernel_type,
  7. double degree,
  8. double coef0,
  9. double Cvalue,
  10. double p,
  11. CvMat* class_weights,
  12. CvTermCriteria term_crit );
  13. CV_PROP_RW int svm_type;
  14. CV_PROP_RW int kernel_type;
  15. CV_PROP_RW double degree; // for poly
  16. CV_PROP_RW double gamma; // for poly/rbf/sigmoid
  17. CV_PROP_RW double coef0; // for poly/sigmoid
  18. CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
  19. CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
  20. CV_PROP_RW double p; // for CV_SVM_EPS_SVR
  21. CvMat* class_weights; // for CV_SVM_C_SVC
  22. CV_PROP_RW CvTermCriteria term_crit; // termination criteria
  23. };
复制代码

 

所以在例程中我们定义了一个结构体变量用来配置这些参数,而这个变量也就是CVSVM类中train函数的第五个参数,下面对参数进行说明。 

  • SVM_params.svm_type :SVM的类型: 
  • C_SVC表示SVM分类器,C_SVR表示SVM回归 
  • SVM_params.kernel_type:核函数类型 
  • 线性核LINEAR: 
  • d(x,y)=(x,y) 
  • 多项式核POLY: 
  • d(x,y)=(gamma*(x’y)+coef0)degree 
  • 径向基核RBF: 
  • d(x,y)=exp(-gamma*|x-y|^2) 
  • sigmoid核SIGMOID: 
  • d(x,y)= tanh(gamma*(x’y)+ coef0)
  • SVM_params.degree:核函数中的参数degree,针对多项式核函数; 
  • SVM_params.gama:核函数中的参数gamma,针对多项式/RBF/SIGMOID核函数; 
  • SVM_params.coef0:核函数中的参数,针对多项式/SIGMOID核函数; 
  • SVM_params.c:SVM最优问题参数,设置C-SVC,EPS_SVR和NU_SVR的参数; 
  • SVM_params.nu:SVM最优问题参数,设置NU_SVC, ONE_CLASS 和NU_SVR的参数; 
  • SVM_params.p:SVM最优问题参数,设置EPS_SVR 中损失函数p的值. 

 

训练模型

  1. CvSVM svm;
  2. svm.train(trainingData, classes, Mat(), Mat(), SVM_params);
复制代码

 

通过上面的过程,我们准备好了待训练的数据和训练需要的参数,其实可以理解为这个准备工作就是在为svm.train()函数准备实参的过程。来看一下svm.train()函数,Opencv将SVM封装成CvSVM库,这个库是基于台湾大学林智仁(Lin Chih-Jen)教授等人开发的LIBSVM封装的,由于篇幅限制,不再全部粘贴库的定义,所以一下代码只是CvSVM库中的一部分数据和函数:

  1. class CV_EXPORTS_W CvSVM : public CvStatModel
  2. {
  3. public:
  4. virtual bool train(
  5. const CvMat* trainData,
  6. const CvMat* responses,
  7. const CvMat* varIdx=0,
  8. const CvMat* sampleIdx=0,
  9. CvSVMParams params=CvSVMParams() );
  10. virtual float predict(
  11. const CvMat* sample,
  12. bool returnDFVal=false ) const;
复制代码

 

我们就是应用类中定义的train函数完成模型训练工作。 
 

保存模型

  1. svm.save("svm.xml");
复制代码

 

保存模型只有一行代码,利用save()函数,我们看下它的定义:

  1. CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
复制代码

 

该函数被定义在CvStatModel类中,CvStatModel是ML库中的统计模型基类,其他 ML 类都是从这个类中继承。

总结:到这里我们就完成了模型训练工作,可以看到真正用于训练的代码其实很少,OpenCV最支持向量机的封装极大地降低了我们的编程工作。

  加载模型实现分类
  1. #include <stdio.h>
  2. #include <time.h>
  3. #include <opencv2/opencv.hpp>
  4. #include <opencv/cv.h>
  5. #include <iostream>
  6. #include <opencv2/core/core.hpp>
  7. #include <opencv2/highgui/highgui.hpp>
  8. #include <opencv2/ml/ml.hpp>
  9. #include <io.h>
  10. using namespace std;
  11. using namespace cv;
  12. void getFiles( string path, vector<string>& files );
  13. int main()
  14. {
  15. int result = 0;
  16. char * filePath = "D:\\data\\test_image\\0";
  17. vector<string> files;
  18. getFiles(filePath, files );
  19. int number = files.size();
  20. cout<<number<<endl;
  21. CvSVM svm;
  22. svm.clear();
  23. string modelpath = "svm.xml";
  24. FileStorage svm_fs(modelpath,FileStorage::READ);
  25. if(svm_fs.isOpened())
  26. {
  27. svm.load(modelpath.c_str());
  28. }
  29. for (int i = 0;i < number;i++)
  30. {
  31. Mat inMat = imread(files[i].c_str());
  32. Mat p = inMat.reshape(1, 1);
  33. p.convertTo(p, CV_32FC1);
  34. int response = (int)svm.predict(p);
  35. if (response == 0)
  36. {
  37. result++;
  38. }
  39. }
  40. cout<<result<<endl;
  41. getchar();
  42. return 0;
  43. }
  44. void getFiles( string path, vector<string>& files )
  45. {
  46. long hFile = 0;
  47. struct _finddata_t fileinfo;
  48. string p;
  49. if((hFile = _findfirst(p.assign(path).append("\\*").c_str(),&fileinfo)) != -1)
  50. {
  51. do
  52. {
  53. if((fileinfo.attrib & _A_SUBDIR))
  54. {
  55. if(strcmp(fileinfo.name,".") != 0 && strcmp(fileinfo.name,"..") != 0)
  56. getFiles( p.assign(path).append("\").append(fileinfo.name), files );
  57. }
  58. else
  59. { files.push_back(p.assign(path).append("\").append(fileinfo.name) );
  60. }
  61. }while(_findnext(hFile, &fileinfo) == 0);
  62. _findclose(hFile);
  63. }
  64. }
复制代码

 

在上面我们把该介绍的都说的差不多了,这个例程中只是用到了load()函数用于模型加载,加载的就是上面例子中生成的模型,load()被定义在CvStatModel这个基类中:

  1. svm.load(modelpath.c_str());
复制代码

  

load的路径是string modelpath = "svm.xml",这意味着svm.mxl文件应该在测试工程的根目录下面,但是因为训练和预测是两个**的工程,所以必须要拷贝一下这个文件。最后用到predict()函数用来预测分类结果,predict()被定义在CVSVM类中。

注意:

1.为什么要建立三个**的工程呢? 
主要是考虑写在一起话,代码量会比较大,逻辑没有分开清晰,当跑通上面的代码之后,就可以随意的改了。 
2.为什么加上数据准备? 
之前有评论说道数据的问题,提供数据后实验能更顺利一些,因为本身代码没有什么含金量,这样可以更顺利的运行起来工程,并修改它。 
3.一些容易引起异常的情况: 
(1)注意生成的.xml记得拷贝到预测工程下; 
(2)注意准备好数据路径和代码是不是一致; 
(3)注意训练的特征要和测试的特征一致。

回复

使用道具 举报

 
*滑块验证:
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则


1关注

0粉丝

1603帖子

排行榜

关注我们:微信订阅号

官方微信

APP下载

全国服务热线:

4000-018-018

公司地址:上海市嘉定区银翔路655号B区1068室

运营中心:成都市锦江区东华正街42号广电仕百达国际大厦25楼

邮编:610066 Email:3318850993#qq.com

Copyright   ©2015-2016  比特趋势Powered by©Discuz!技术支持:迪恩网络