Python Support Vector Machines (Hands-on SVM Handwritten Digit Recognition)
I. Linear SVM: svm.LinearSVC()
1. Use grid search to choose the best C for the linear SVM class LinearSVC
# Import third-party modules
from sklearn import svm
import pandas as pd
from sklearn import model_selection
from sklearn import metrics

# Read the external data
Hw = pd.read_csv(r'Handwritten.csv')

# Select the feature columns (the first 1024 columns)
predict = Hw.columns[:1024]
Train = Hw[predict]
# Print the first 5 rows
print(Train.head())

# Select the label column (column 1024)
label = Hw[Hw.columns[1024]]
print(label.head())

# Hold out 25% of the data as the test set
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    Train, label, test_size=0.25, random_state=123)

# Grid search for the best C of the linear SVM
C = [0.05, 0.1, 0.5, 1, 2, 5]
parameters = {'C': C}
grid_linear_svc = model_selection.GridSearchCV(
    estimator=svm.LinearSVC(), param_grid=parameters,
    scoring='accuracy', cv=5, verbose=1)
# Fit the model on the training set
grid_linear_svc.fit(X_train, y_train)
# Best parameter value and its mean cross-validated accuracy
grid_linear_svc.best_params_, grid_linear_svc.best_score_
Result:
({'C': 0.05}, 0.9675775822139879)
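For reference, GridSearchCV with cv=5 scores every candidate C on five train/validation splits of the training set, averages the five accuracies, and keeps the C with the highest mean (that mean is best_score_). The minimal plain-Python sketch below shows only this selection loop, under simplifying assumptions: folds are contiguous and unshuffled (scikit-learn actually uses stratified folds for classifiers), and `fit_and_score` is a hypothetical callback standing in for "fit LinearSVC(C=c) on the training indices and return accuracy on the validation indices".

```python
from statistics import mean
from typing import Callable, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]


def kfold_indices(n_samples: int, k: int) -> List[Split]:
    """Split range(n_samples) into k contiguous, unshuffled folds.

    Returns one (train_indices, val_indices) pair per fold. As in
    scikit-learn's KFold, the first n_samples % k folds get one extra sample.
    """
    splits: List[Split] = []
    start = 0
    for fold in range(k):
        size = n_samples // k + (1 if fold < n_samples % k else 0)
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        splits.append((train, val))
        start += size
    return splits


def grid_search(
    candidates: Sequence[float],
    n_samples: int,
    k: int,
    fit_and_score: Callable[[float, List[int], List[int]], float],
) -> Tuple[float, float]:
    """Return (best_candidate, best_mean_score) under k-fold cross-validation."""
    folds = kfold_indices(n_samples, k)
    # Mean validation score of each candidate over all k folds
    mean_scores = {
        c: mean(fit_and_score(c, train, val) for train, val in folds)
        for c in candidates
    }
    # max() keeps the first candidate among ties, i.e. the earliest in the grid
    best = max(candidates, key=lambda c: mean_scores[c])
    return best, mean_scores[best]


# Hypothetical scorer for demonstration only: pretends accuracy drops as C grows.
def toy_score(c: float, train: List[int], val: List[int]) -> float:
    return 0.97 - 0.001 * c


best_c, best_score = grid_search([0.05, 0.1, 0.5, 1, 2, 5], 100, 5, toy_score)
print(best_c, round(best_score, 5))  # 0.05 0.96995
```

With six C values and five folds, the real grid search above fits 6 × 5 = 30 models, plus one final refit of the best C on the whole training set (GridSearchCV's refit=True default).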
2. Model training and prediction
# Retrain with the best C found by the grid search
F_svc = svm.LinearSVC(C=0.05)
F_svc.fit(X_train, y_train)
# Predict on the test set
pred = F_svc.predict(X_test)
# Prediction accuracy on the test set
metrics.accuracy_score(y_test, pred)
Result:
0.9541666666666667
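C controls the trade-off between a wide margin and training errors. With its default settings (penalty='l2', loss='squared_hinge'), LinearSVC solves, for each binary sub-problem, the problem below, written here with an unpenalized intercept b for readability (the liblinear solver it uses also regularizes the intercept). A small C such as the selected 0.05 puts more weight on the margin term, i.e. stronger regularization. For the ten digit classes, LinearSVC uses one-vs-rest by default: it fits one pair (w_k, b_k) per class and predicts the class with the largest score.

```latex
\min_{w,\,b}\ \frac{1}{2}\lVert w\rVert_2^2
  + C\sum_{i=1}^{n}\max\bigl(0,\ 1 - y_i\,(w^\top x_i + b)\bigr)^2,
\qquad y_i \in \{-1, +1\}

\hat{y}(x) = \arg\max_{k}\ \bigl(w_k^\top x + b_k\bigr)
```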
3. Plotting the confusion matrix
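# Import third-party modules
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics

# Confusion matrix: rows are true labels, columns are predicted labels
cm = metrics.confusion_matrix(y_test, pred)

# fmt='d' prints the counts as integers
sns.heatmap(cm, annot=True, fmt='d', cmap='GnBu')
# heatmap draws matrix columns along the x-axis and rows along the y-axis
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()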
Result:
[Figure: confusion-matrix heatmap of the LinearSVC test-set predictions]
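metrics.confusion_matrix(y_test, pred) puts true labels on the rows and predicted labels on the columns, so entry (i, j) counts samples of class i predicted as class j; the diagonal holds the correct predictions, and accuracy equals the diagonal sum divided by the total. The plain-Python sketch below reproduces that layout for illustration only (it is not scikit-learn's code, though it orders labels the same way, sorted, when no label list is given) and adds a hypothetical helper, `most_confused`, that finds the largest off-diagonal cell.

```python
from typing import List, Optional, Sequence, Tuple


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[List[int], List[List[int]]]:
    """Return (labels, matrix) where matrix[i][j] counts samples whose
    true label is labels[i] and predicted label is labels[j]."""
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return labels, matrix


def accuracy_from_matrix(matrix: List[List[int]]) -> float:
    """Correct predictions (the diagonal) divided by all predictions."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total


def most_confused(labels: List[int], matrix: List[List[int]]) -> Optional[Tuple[int, int, int]]:
    """Hypothetical helper: (true_label, predicted_label, count) of the
    largest off-diagonal cell, or None if there are no off-diagonal cells."""
    best: Optional[Tuple[int, int, int]] = None
    for i, row in enumerate(matrix):
        for j, count in enumerate(row):
            if i != j and (best is None or count > best[2]):
                best = (labels[i], labels[j], count)
    return best


y_true = [0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 1, 1, 2]
labels, cm = confusion_matrix(y_true, y_pred)
print(cm)                         # [[1, 1, 0], [0, 2, 1], [0, 2, 2]]
print(accuracy_from_matrix(cm))   # 0.5555555555555556 (5 of 9 correct)
print(most_confused(labels, cm))  # (2, 1, 2): true 2 predicted as 1, twice
```

In a seaborn heatmap of such a matrix, rows (true labels) run down the y-axis and columns (predicted labels) run across the x-axis, so a bright off-diagonal cell points to a specific pair of digits the model mixes up.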
II. Non-linear SVM: svm.SVC()
1. Use grid search to choose the best C and kernel for the SVC class
# Grid search over both the kernel and C of the non-linear SVM
kernel = ['rbf', 'linear', 'poly', 'sigmoid']
C = [0.1, 0.5, 1, 2, 5]
parameters = {'kernel': kernel, 'C': C}
grid_svc = model_selection.GridSearchCV(
    estimator=svm.SVC(), param_grid=parameters,
    scoring='accuracy', cv=5, verbose=1)
# Fit the model on the training set
grid_svc.fit(X_train, y_train)
# Best parameter values and their mean cross-validated accuracy
grid_svc.best_params_, grid_svc.best_score_
Result:
({'C': 5, 'kernel': 'rbf'}, 0.9856415006947661)
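The four kernels in the grid measure the similarity of two samples x and z in different ways. Following the formulas in scikit-learn's SVC documentation, they are linear ⟨x, z⟩, polynomial (γ⟨x, z⟩ + r)^d, RBF exp(−γ‖x − z‖²) and sigmoid tanh(γ⟨x, z⟩ + r), with SVC defaults d = 3 (degree), r = 0 (coef0) and gamma='scale', i.e. γ = 1 / (n_features · X.var()). The plain-Python sketch below evaluates them; `gamma_scale` assumes X is a list of rows and uses the population variance over all entries, matching NumPy's X.var().

```python
import math
from statistics import pvariance
from typing import List, Sequence


def dot(x: Sequence[float], z: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(x, z))


def linear_kernel(x: Sequence[float], z: Sequence[float]) -> float:
    return dot(x, z)


def poly_kernel(x: Sequence[float], z: Sequence[float], gamma: float,
                degree: int = 3, coef0: float = 0.0) -> float:
    return (gamma * dot(x, z) + coef0) ** degree


def rbf_kernel(x: Sequence[float], z: Sequence[float], gamma: float) -> float:
    # exp(-gamma * squared Euclidean distance): 1.0 for identical samples
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, z))
    return math.exp(-gamma * sq_dist)


def sigmoid_kernel(x: Sequence[float], z: Sequence[float], gamma: float,
                   coef0: float = 0.0) -> float:
    return math.tanh(gamma * dot(x, z) + coef0)


def gamma_scale(X: List[List[float]]) -> float:
    """gamma='scale': 1 / (n_features * population variance of all entries of X)."""
    n_features = len(X[0])
    values = [v for row in X for v in row]
    return 1.0 / (n_features * pvariance(values))


X = [[0.0, 1.0], [1.0, 0.0]]
g = gamma_scale(X)  # variance 0.25, two features, so gamma = 2.0
print(g, rbf_kernel(X[0], X[0], g), rbf_kernel(X[0], X[1], g))
```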
2. Model training and prediction
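# Retrain with the best parameters found by the grid search
svm_svc = svm.SVC(C=5, kernel='rbf')
svm_svc.fit(X_train, y_train)

# Predict on the test set with the retrained model
pred_svc = svm_svc.predict(X_test)
# Prediction accuracy on the test set
metrics.accuracy_score(y_test, pred_svc)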
Result:
0.9763888888888889
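Unlike LinearSVC, SVC handles the ten digit classes with a one-vs-one scheme: it trains one binary classifier per pair of classes (10 × 9 / 2 = 45 for digits 0 to 9), each pair casts one vote at prediction time, and the class with the most votes wins. The sketch below shows only the voting step; `closest_label_wins` is a hypothetical stand-in for a trained binary RBF classifier, and ties go to the class listed first, which is how the scikit-learn documentation describes the default break_ties=False behaviour.

```python
from itertools import combinations
from typing import Callable, Dict, Sequence

# pairwise_winner(x, a, b) returns whichever of classes a and b wins for sample x
PairwiseWinner = Callable[[Sequence[float], int, int], int]


def ovo_predict(x: Sequence[float], classes: Sequence[int],
                pairwise_winner: PairwiseWinner) -> int:
    """One-vs-one voting: every pair of classes casts one vote; most votes wins."""
    votes: Dict[int, int] = {c: 0 for c in classes}
    for a, b in combinations(classes, 2):
        votes[pairwise_winner(x, a, b)] += 1
    # max() returns the first class with the top count, so ties go to the earlier class
    return max(classes, key=lambda c: votes[c])


# Hypothetical binary "classifiers" for illustration: the class whose label
# is closer to the sample's first feature wins the pair.
def closest_label_wins(x: Sequence[float], a: int, b: int) -> int:
    return a if abs(x[0] - a) <= abs(x[0] - b) else b


digits = list(range(10))
print(len(list(combinations(digits, 2))))              # 45 binary classifiers
print(ovo_predict([6.2], digits, closest_label_wins))  # 6
```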
3. Plotting the confusion matrix
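import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics

# Confusion matrix: rows are true labels, columns are predicted labels
cm = metrics.confusion_matrix(y_test, pred_svc)

# fmt='d' prints the counts as integers
sns.heatmap(cm, annot=True, fmt='d', cmap='PuBu_r')
# heatmap draws matrix columns along the x-axis and rows along the y-axis
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()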
Result:
[Figure: confusion-matrix heatmap of the SVC test-set predictions]
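III. Summary
On this handwritten digit recognition task, the non-linear SVC (RBF kernel, C=5) reached a test accuracy of 0.9764, compared with 0.9542 for LinearSVC (C=0.05), so its test error rate is roughly half (about 2.4% versus 4.6%). When choosing a model, we can use theory to shortlist suitable candidates and then tune their hyperparameters, for example with grid search as above.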