以代码为基础的opencv-python学习 利用词袋和SVM进行汽车检验

'''
申明:
Code来源于《Learning OpenCV 3 Computer Vision with Python》
'''
import cv2
import numpy as np
from os.path import join

#申明训练图像基础路径
datapath = 'C:\\root\\learn\\python\\img\\TrainImages' #UIUC 汽车数据图片集

#训练集中图片名称为pos-1.pgm,pos-2.pgm neg-1.pgm,
def path(cls,i):
return "%s/%s%d.pgm"%(datapath,cls,i+1)

pos, neg = "pos-","neg-"

#创建两个SIFT实例,一个用于提取关键点,一个用于提取特征
detect = cv2.xfeatures2d.SIFT_create()
extract = cv2.xfeatures2d.SIFT_create()

#定义SIFT的特征匹配算法,使用FLANN匹配器
flann_params = dict(algorithm=1,trees=5)
flann = cv2.FlannBasedMatcher(flann_params,{})

#创建BOW(词袋)训练器
bow_kmeans_trainer = cv2.BOWKMeansTrainer(40)
#初始化BOW提取器
extract_bow = cv2.BOWImgDescriptorExtractor(extract,flann)

#得到描述符
def extract_sift(fn):
im = cv2.imread(fn,0)
return extract.compute(im,detect.detect(im))[1]

#读取8个正负样本
for i in range(8):
bow_kmeans_trainer.add(extract_sift(path(pos,i)))
bow_kmeans_trainer.add(extract_sift(path(neg,i)))
#使用cluster来创建视觉单词词汇,使用k-means分类
voc = bow_kmeans_trainer.cluster()
extract_bow.setVocabulary(voc)

#返回基于BOW的描述符提取器计算得到的描述符
def bow_features(fn):
im = cv2.imread(fn,0)
return extract_bow.compute(im,detect.detect(im))

#进行监督学习
traindata, trainlabels = [],[]
for i in range(20):
#print("pos path=",path(pos,i))
traindata.extend(bow_features(path(pos,i)))
trainlabels.append(1)
traindata.extend(bow_features(path(neg,i)))
trainlabels.append(-1)

#创建SVM并使用SVM进行训练
svm = cv2.ml.SVM_create()
svm.train(np.array(traindata),cv2.ml.ROW_SAMPLE,np.array(trainlabels))

#预测函数
def predict(fn):
f = bow_features(fn)
p = svm.predict(f)
print(fn,"*",p[1][0][0])
return p

car, notcar = "car_blur.jpg","woman.jpg"
car_img = cv2.imread(car)
notcar_img = cv2.imread(notcar)

car_predict = predict(car)
not_car_predict = predict(notcar)

font = cv2.FONT_HERSHEY_SIMPLEX

if car_predict[1][0][0]==1.0:
cv2.putText(car_img,'Car Detected',(10,30),font,1,(0,255,255),2,cv2.LINE_AA)

if not_car_predict[1][0][0]==-1.0:
cv2.putText(notcar_img, 'Car Not Detected', (10, 30), font, 1, (0, 0, 255), 2, cv2.LINE_AA)

cv2.imshow('BOW+SVM Success',car_img)
cv2.imshow('BOW+SVM Failure',notcar_img)
cv2.waitKey(0)
cv2.destroyAllWindows()

posted on 2020-04-11 16:17  画扇2020  阅读(422)  评论(0)    收藏  举报