[Repost] Traffic light detection with OpenCV

Full project source code: GitHub

Introduction

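In the previous post we covered traffic sign recognition; now we move on to recognizing traffic lights.
We will implement and refine the whole project step by step, following our own approach.
In this project we recognize traffic lights using the HSV color space. Possible improvements:

  • Use Faster R-CNN or SSD to recognize traffic lights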

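The first step is to load the data and visualize some of it in the RGB and HSV color spaces. The images come from the MIT self-driving car course and fall into three classes (red, green, yellow): 1187 images in total, of which 723 show red lights, 429 green lights and 35 yellow lights.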

Import libraries

```python
# import some libs
import cv2
import os
import glob
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
```
```python
# Image data directories
IMAGE_DIR_TRAINING = "traffic_light_images/training/"
IMAGE_DIR_TEST = "traffic_light_images/test/"

# Load data
def load_dataset(image_dir):
    '''
    Loads images and their labels and places them in a list of (image, label) tuples.
    image_dir: directory where the images are stored, one sub-folder per color
    '''
    im_list = []
    image_types = ['red', 'yellow', 'green']
    # Iterate through each color folder
    for im_type in image_types:
        file_lists = glob.glob(os.path.join(image_dir, im_type, '*'))
        print(len(file_lists))
        for file in file_lists:
            im = mpimg.imread(file)
            if im is not None:
                im_list.append((im, im_type))
    return im_list

IMAGE_LIST = load_dataset(IMAGE_DIR_TRAINING)
```
```
723
35
429
```
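The printed counts show a strong class imbalance. Yellow lights account for only 35 of the 1187 training images, or about 3%. Below is a minimal sketch that summarizes the labels of a list shaped like IMAGE_LIST, a list of (image, label) tuples. The helper name `label_distribution` is hypothetical and not part of the original project.

```python
from collections import Counter

def label_distribution(image_list):
    """Return {label: (count, fraction)} for a list of (image, label) tuples.

    Hypothetical helper for illustration; only the labels are inspected.
    """
    counts = Counter(label for _, label in image_list)
    total = sum(counts.values())
    return {label: (n, n / total) for label, n in counts.items()}

# Usage with the training list loaded above:
# for label, (n, frac) in label_distribution(IMAGE_LIST).items():
#     print(f"{label:>6}: {n:4d} ({frac:.1%})")
```

Because yellow is so rare, a single overall accuracy number says little about how well yellow lights specifically are recognized.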

Visualize the data

The visualization below:

  • displays each image
  • prints the image size
  • prints the image's label
```python
_, ax = plt.subplots(1, 3, figsize=(5, 2))
# red
img_red = IMAGE_LIST[0][0]
ax[0].imshow(img_red)
ax[0].annotate(IMAGE_LIST[0][1], xy=(2, 5), color='blue', fontsize=10)
ax[0].axis('off')
ax[0].set_title(str(img_red.shape), fontsize=10)
# yellow
img_yellow = IMAGE_LIST[730][0]
ax[1].imshow(img_yellow)
ax[1].annotate(IMAGE_LIST[730][1], xy=(2, 5), color='blue', fontsize=10)
ax[1].axis('off')
ax[1].set_title(str(img_yellow.shape), fontsize=10)
# green
img_green = IMAGE_LIST[800][0]
ax[2].imshow(img_green)
ax[2].annotate(IMAGE_LIST[800][1], xy=(2, 5), color='blue', fontsize=10)
ax[2].axis('off')
ax[2].set_title(str(img_green.shape), fontsize=10)
plt.show()
```

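The hard-coded indices 730 and 800 work only because load_dataset appends all red images first, then yellow, then green. With 723 red and 35 yellow images, indices 0-722 are red, 723-757 are yellow, and green starts at 758. The sketch below looks up the first example of each label instead of relying on that ordering. `first_index_of` is a hypothetical helper.

```python
def first_index_of(image_list, label):
    """Return the index of the first (image, label) pair with the given label.

    Hypothetical helper; raises ValueError if the label does not occur.
    """
    for i, (_, lbl) in enumerate(image_list):
        if lbl == label:
            return i
    raise ValueError(f"label {label!r} not found")

# Usage with the training list:
# img_yellow = IMAGE_LIST[first_index_of(IMAGE_LIST, 'yellow')][0]
```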

Preprocess Data

With the data loaded, we next need to standardize both the input and the output.

Input

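The figure above shows that the images are not all the same size, so we standardize the input by resizing every image to the same dimensions. A classifier applies the same algorithm to every image, which is why standardizing the images is so important.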

Output

Our labels are categorical ('red', 'yellow', 'green'), so we use one-hot encoding to convert them into numeric data.

```python
# Standardize the input images: resize every image to 32x32x3.
# Cropping, translation or rotation could also be applied at this step.
def standardize(image_list):
    '''
    Takes a list of (RGB image, label) pairs and returns a standardized version:
    every image resized to 32x32x3 and every label one-hot encoded.
    image_list: list of (image, label) tuples
    '''
    standard_list = []
    # Iterate through all the image-label pairs
    for item in image_list:
        image = item[0]
        label = item[1]
        # Standardize the input
        standardized_im = standardize_input(image)
        # Standardize the output (one-hot)
        one_hot_label = one_hot_encode(label)
        # Append the image and its one-hot encoded label to the processed list
        standard_list.append((standardized_im, one_hot_label))
    return standard_list

def standardize_input(image):
    # Resize all images to 32x32x3 (cv2.resize takes dsize as (width, height))
    standard_im = cv2.resize(image, (32, 32))
    return standard_im

def one_hot_encode(label):
    '''
    Return the one-hot encoded label:
    one_hot_encode("red")    -> [1, 0, 0]
    one_hot_encode("yellow") -> [0, 1, 0]
    one_hot_encode("green")  -> [0, 0, 1]
    '''
    if label == 'red':
        return [1, 0, 0]
    elif label == 'yellow':
        return [0, 1, 0]
    else:
        return [0, 0, 1]
```
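The if/elif chain maps every label other than 'red' or 'yellow' to green. A typo such as 'gren' would therefore silently become [0, 0, 1]. The sketch below is a table-driven alternative built on NumPy's identity matrix. It rejects unknown labels and includes an inverse that turns a one-hot (or score) vector back into a label. `LABELS`, `one_hot` and `decode_one_hot` are illustrative names, not part of the original code.

```python
import numpy as np

# Order matches the original encoding: [red, yellow, green]
LABELS = ['red', 'yellow', 'green']

def one_hot(label):
    """Return the one-hot list for label; unknown labels raise ValueError."""
    index = LABELS.index(label)  # list.index raises ValueError for unknown labels
    return np.eye(len(LABELS), dtype=int)[index].tolist()

def decode_one_hot(vector):
    """Map a one-hot (or score) vector back to its label via argmax."""
    return LABELS[int(np.argmax(vector))]
```

Returning a plain Python list (via `.tolist()`) keeps the result comparable with `==` and `assertEqual`, just like the lists returned by one_hot_encode.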

Test your code

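Having implemented the standardization code, we should confirm that it is correct, so next we write a function to test it.
Building automated tests in Python means organizing test cases and running them; for this we recommend unittest from Python's standard library.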

```python
import unittest
from IPython.display import Markdown, display

# Helper function for printing markdown text (text in color/bold/etc.)
def printmd(string):
    display(Markdown(string))

# Print a "test failed" message
def print_fail():
    printmd('<span style="color: red;">Test Failed</span>')

# Print a "test passed" message
def print_pass():
    printmd('<span style="color: green;">Test Passed</span>')

# A class holding all tests; the methods are called directly with arguments
class Tests(unittest.TestCase):
    # Tests the one_hot_encode function, which is passed in as an argument
    def test_one_hot(self, one_hot_function):
        # Check that the generated one-hot labels match the expected ones
        # for all three cases (red, yellow, green)
        try:
            self.assertEqual([1, 0, 0], one_hot_function('red'))
            self.assertEqual([0, 1, 0], one_hot_function('yellow'))
            self.assertEqual([0, 0, 1], one_hot_function('green'))
        # Any mismatch raises failureException
        except self.failureException as e:
            # Print out an error message
            print_fail()
            print('Your function did not return the expected one-hot label')
            print('\n' + str(e))
            return
        print_pass()

    # Test whether any misclassified red light was predicted as green
    def test_red_as_green(self, misclassified_images):
        # Loop through each misclassified image and its labels
        for im, predicted_label, true_label in misclassified_images:
            # Only check images whose true label is red
            if true_label == [1, 0, 0]:
                try:
                    # Green is encoded as [0, 0, 1]
                    self.assertNotEqual(predicted_label, [0, 0, 1])
                except self.failureException as e:
                    print_fail()
                    print('Warning: a red light is classified as green.')
                    print('\n' + str(e))
                    return
        print_pass()

tests = Tests()
tests.test_one_hot(one_hot_encode)
```

Test Passed
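The Tests class above is called by hand with extra arguments. That works in a notebook, but if unittest's test runner collected those methods it would call them with only `self` and fail with a TypeError. The sketch below shows the conventional layout, where each test method takes only `self`. It assumes the function under test is in scope, so a copy of one_hot_encode is included to let the block run on its own.

```python
import unittest

def one_hot_encode(label):
    # Copy of the encoder from the preprocessing step, so this block is self-contained
    if label == 'red':
        return [1, 0, 0]
    elif label == 'yellow':
        return [0, 1, 0]
    else:
        return [0, 0, 1]

class OneHotTests(unittest.TestCase):
    # Each method takes only self, so the test runner can discover and run it
    def test_red(self):
        self.assertEqual(one_hot_encode('red'), [1, 0, 0])

    def test_yellow(self):
        self.assertEqual(one_hot_encode('yellow'), [0, 1, 0])

    def test_green(self):
        self.assertEqual(one_hot_encode('green'), [0, 0, 1])

if __name__ == '__main__':
    # A dummy argv stops unittest.main from parsing Jupyter's command line;
    # exit=False keeps it from calling sys.exit and stopping the kernel
    unittest.main(argv=['ignored'], exit=False)
```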

```python
Standardized_Train_List = standardize(IMAGE_LIST)
```

Feature Extraction

Here we use color spaces, shape analysis and feature construction.

RGB to HSV

```python
# Visualize
image_num = 0
test_im = Standardized_Train_List[image_num][0]
test_label = Standardized_Train_List[image_num][1]
# Convert to HSV
hsv = cv2.cvtColor(test_im, cv2.COLOR_RGB2HSV)
# Print image label
print('Label [red, yellow, green]: ' + str(test_label))
h = hsv[:, :, 0]
s = hsv[:, :, 1]
v = hsv[:, :, 2]
# Plot the original image and the three channels
_, ax = plt.subplots(1, 4, figsize=(20, 10))
ax[0].set_title('Standardized image')
ax[0].imshow(test_im)
ax[1].set_title('H channel')
ax[1].imshow(h, cmap='gray')
ax[2].set_title('S channel')
ax[2].imshow(s, cmap='gray')
ax[3].set_title('V channel')
ax[3].imshow(v, cmap='gray')
```
```
Label [red, yellow, green]: [1, 0, 0]
<matplotlib.image.AxesImage at 0x7fb49ad71f28>
```

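The hue bands used in the classifier (red 150-180, yellow 10-60, green 70-100) are on OpenCV's 8-bit scale. On that scale, hue in degrees is halved so it fits in 0-179, and S and V are scaled to 0-255. The small sketch below uses Python's standard colorsys module to convert a few pure colors to that scale. `to_opencv_hsv` is a hypothetical helper that assumes this 8-bit convention (H = degrees / 2).

```python
import colorsys

def to_opencv_hsv(r, g, b):
    """Convert an 8-bit RGB triple to OpenCV-style 8-bit HSV.

    colorsys works on floats in [0, 1] and returns hue as a fraction of 360 degrees;
    OpenCV's 8-bit convention stores H as degrees / 2 and S, V scaled to 0-255.
    """
    h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)
    return round(h * 180), round(s * 255), round(v * 255)

for name, rgb in [('red', (255, 0, 0)), ('yellow', (255, 255, 0)), ('green', (0, 255, 0))]:
    print(name, to_opencv_hsv(*rgb))
# red (0, 255, 255)
# yellow (30, 255, 255)
# green (60, 255, 255)
```

Pure green lands at H = 60, below the 70-100 green band. That band corresponds to 140-200 degrees, a bluish green, presumably chosen to match the tint of many signal lamps. Likewise, the red band 150-180 covers only 300-360 degrees, so reds with hue just above 0 (for example H = 5) fall outside every band.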

```python
# Create features
'''
HSV stands for Hue, Saturation, Value (also called HSB, where B is Brightness).
Hue (H) is the basic attribute of a color, i.e. its everyday name such as red or yellow.
Saturation (S) is the purity of a color: higher values are purer, lower values fade
toward gray. Usually given as 0-100%.
Value (V), or brightness/lightness, is also given as 0-100%.
Note: for 8-bit images OpenCV stores H in [0, 179] and S, V in [0, 255].
'''
def create_feature(rgb_image):
    '''
    Basic brightness feature
    rgb_image: a standardized 32x32 RGB image
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    sum_brightness = np.sum(hsv[:, :, 2])
    area = 32 * 32
    avg_brightness = sum_brightness / area  # Find the average
    return avg_brightness

def high_saturation_pixels(rgb_image, threshold=80):
    '''
    Returns average red and green content from high saturation pixels.
    Usually, the traffic light contains the highest saturation pixels in the image.
    The threshold was experimentally determined to be 80.
    '''
    high_sat_pixels = []
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    for i in range(32):
        for j in range(32):
            if hsv[i][j][1] > threshold:
                high_sat_pixels.append(rgb_image[i][j])
    # Fall back to the single most saturated pixel if none pass the threshold
    if not high_sat_pixels:
        return highest_sat_pixel(rgb_image)
    sum_red = 0
    sum_green = 0
    for pixel in high_sat_pixels:
        # Convert to int so uint8 values cannot wrap around while summing
        sum_red += int(pixel[0])
        sum_green += int(pixel[1])
    # (sum() could replace the manual accumulation above)
    avg_red = sum_red / len(high_sat_pixels)
    # Green is scaled by 0.8, which tips close cases toward red
    avg_green = sum_green / len(high_sat_pixels) * 0.8
    return avg_red, avg_green

def highest_sat_pixel(rgb_image):
    '''
    Finds the highest saturation pixel and checks whether it has a higher green
    or a higher red content
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    s = hsv[:, :, 1]
    x, y = np.unravel_index(np.argmax(s), s.shape)
    if rgb_image[x, y, 0] > rgb_image[x, y, 1] * 0.9:
        return 1, 0  # red has a higher content
    return 0, 1
```
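high_saturation_pixels walks the 32x32 grid with two Python loops and accumulates the pixel values one by one. The same selection can be written with a NumPy boolean mask. The pixels are converted to float before averaging, so integer overflow cannot occur. This is a sketch under two assumptions: the saturation channel is passed in directly (in the original it comes from cv2.cvtColor), and the name `high_saturation_means` is hypothetical. The threshold of 80 and the green weight of 0.8 are taken from the code above.

```python
import numpy as np

def high_saturation_means(rgb_image, saturation, threshold=80, green_weight=0.8):
    """Mean red and weighted mean green over pixels with saturation > threshold.

    rgb_image:  H x W x 3 uint8 array in RGB order
    saturation: H x W array, e.g. hsv[:, :, 1] from cv2.cvtColor(..., cv2.COLOR_RGB2HSV)
    Returns None when no pixel passes, so the caller can fall back to
    highest_sat_pixel as the original code does.
    """
    mask = saturation > threshold                    # boolean H x W mask
    if not mask.any():
        return None
    selected = rgb_image[mask].astype(np.float64)    # N x 3 array of the chosen pixels
    avg_red = selected[:, 0].mean()
    avg_green = selected[:, 1].mean() * green_weight
    return avg_red, avg_green
```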

Test dataset

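Next we load the test set to see how accurate this approach is.
The methods above compute:
1. the average brightness
2. the average red and green content of the most saturated pixels
One might ask why yellow is not handled at all, so we make the following improvement.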
reference url

Some of the thresholds below are taken directly from the values given on Wikipedia.

```python
def estimate_label(rgb_image, display=False):
    '''
    rgb_image: standardized RGB image
    '''
    return red_green_yellow(rgb_image, display)

def findNoneZero(rgb_image):
    # Count the pixels that are not completely black, i.e. the ones kept by a mask
    rows, cols, _ = rgb_image.shape
    counter = 0
    for row in range(rows):
        for col in range(cols):
            pixels = rgb_image[row, col]
            # np.any avoids the wrap-around that sum() can hit on uint8 pixels
            if np.any(pixels):
                counter = counter + 1
    return counter

def red_green_yellow(rgb_image, display):
    '''
    Determines the red, green and yellow content of an image using HSV and
    experimentally determined thresholds. Returns a classification based on the values.
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    sum_saturation = np.sum(hsv[:, :, 1])  # Sum the saturation values
    area = 32 * 32
    avg_saturation = sum_saturation / area  # Find the average
    sat_low = int(avg_saturation * 1.3)  # 1.3x the mean, an empirical engineering choice
    val_low = 140
    # Green
    lower_green = np.array([70, sat_low, val_low])
    upper_green = np.array([100, 255, 255])
    green_mask = cv2.inRange(hsv, lower_green, upper_green)
    green_result = cv2.bitwise_and(rgb_image, rgb_image, mask=green_mask)
    # Yellow
    lower_yellow = np.array([10, sat_low, val_low])
    upper_yellow = np.array([60, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
    yellow_result = cv2.bitwise_and(rgb_image, rgb_image, mask=yellow_mask)
    # Red
    lower_red = np.array([150, sat_low, val_low])
    upper_red = np.array([180, 255, 255])
    red_mask = cv2.inRange(hsv, lower_red, upper_red)
    red_result = cv2.bitwise_and(rgb_image, rgb_image, mask=red_mask)
    if display:
        _, ax = plt.subplots(1, 5, figsize=(20, 10))
        ax[0].set_title('rgb image')
        ax[0].imshow(rgb_image)
        ax[1].set_title('red result')
        ax[1].imshow(red_result)
        ax[2].set_title('yellow result')
        ax[2].imshow(yellow_result)
        ax[3].set_title('green result')
        ax[3].imshow(green_result)
        ax[4].set_title('hsv image')
        ax[4].imshow(hsv)
        plt.show()
    sum_green = findNoneZero(green_result)
    sum_red = findNoneZero(red_result)
    sum_yellow = findNoneZero(yellow_result)
    # Pick the color with the most surviving pixels; ties go to red, then yellow
    if sum_red >= sum_yellow and sum_red >= sum_green:
        return [1, 0, 0]  # Red
    if sum_yellow >= sum_green:
        return [0, 1, 0]  # Yellow
    return [0, 0, 1]  # Green
```
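cv2.inRange keeps a pixel when every channel lies inside its [lower, upper] bounds (inclusive), and findNoneZero then counts the surviving pixels. Every band uses val_low = 140, and OpenCV computes V as max(R, G, B). A masked pixel therefore can never be pure black, so counting the mask itself gives the same number. Below is a NumPy-only sketch of the same decision, with the same hue bands and the same tie-breaking order. It takes an already-converted HSV image, and the names `in_range` and `classify_hsv` are hypothetical.

```python
import numpy as np

def in_range(hsv, lower, upper):
    """NumPy analogue of cv2.inRange: True where lower <= pixel <= upper on every channel."""
    lower = np.asarray(lower)
    upper = np.asarray(upper)
    return np.all((hsv >= lower) & (hsv <= upper), axis=2)

def classify_hsv(hsv, val_low=140):
    """Return [1,0,0], [0,1,0] or [0,0,1] using the hue bands from red_green_yellow."""
    area = hsv.shape[0] * hsv.shape[1]
    sat_low = int(hsv[:, :, 1].sum() / area * 1.3)  # 1.3x the mean saturation
    counts = {}
    for name, h_lo, h_hi in [('red', 150, 180), ('yellow', 10, 60), ('green', 70, 100)]:
        mask = in_range(hsv, [h_lo, sat_low, val_low], [h_hi, 255, 255])
        counts[name] = int(np.count_nonzero(mask))
    # Same tie-breaking as the original: red wins ties, then yellow
    if counts['red'] >= counts['yellow'] and counts['red'] >= counts['green']:
        return [1, 0, 0]
    if counts['yellow'] >= counts['green']:
        return [0, 1, 0]
    return [0, 0, 1]
```

Applied to `cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)` for a standardized 32x32 image, it should make the same decision as red_green_yellow. Like the original, it returns red when no band matches at all.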

Test

Next we run the classifier on the three sample images from the visualization step (img_red, img_yellow, img_green) to see how it does.

```python
img_test = [(img_red, 'red'), (img_yellow, 'yellow'), (img_green, 'green')]
standardtest = standardize(img_test)
for img in standardtest:
    predicted_label = estimate_label(img[0], display=True)
    print('Predict label :', predicted_label)
    print('True label:', img[1])
```


```
Predict label : [1, 0, 0]
True label: [1, 0, 0]
```


```
Predict label : [0, 1, 0]
True label: [0, 1, 0]
```


```
Predict label : [0, 0, 1]
True label: [0, 0, 1]
```
```python
# Load the test data with the load_dataset function defined above
TEST_IMAGE_LIST = load_dataset(IMAGE_DIR_TEST)
# Standardize the test data
STANDARDIZED_TEST_LIST = standardize(TEST_IMAGE_LIST)
# Shuffle the standardized test data
random.shuffle(STANDARDIZED_TEST_LIST)
```
```
181
9
107
```

Determine the Accuracy

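Now let's check the accuracy of our algorithm on the test set. The code below collects every misclassified image together with its predicted and true labels and stores them in MISCLASSIFIED.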

```python
# Constructs a list of misclassified images given a list of test images and their labels.
# This will throw an AssertionError if labels are not standardized (one-hot encoded).
def get_misclassified_images(test_images, display=False):
    misclassified_images_labels = []
    # Iterate through all the test images;
    # classify each image and compare it to the true label
    for image in test_images:
        # Get true data
        im = image[0]
        true_label = image[1]
        assert len(true_label) == 3, 'This true_label is not the expected length (3).'
        # Get predicted label from the classifier
        predicted_label = estimate_label(im, display=display)
        assert len(predicted_label) == 3, 'This predicted_label is not the expected length (3).'
        # Compare true and predicted labels
        if predicted_label != true_label:
            # If these labels are not equal, the image has been misclassified
            misclassified_images_labels.append((im, predicted_label, true_label))
    # Return the list of misclassified (image, predicted_label, true_label) values
    return misclassified_images_labels

# Find all misclassified images in a given test set
MISCLASSIFIED = get_misclassified_images(STANDARDIZED_TEST_LIST, display=False)
# Accuracy calculations
total = len(STANDARDIZED_TEST_LIST)
num_correct = total - len(MISCLASSIFIED)
accuracy = num_correct / total
print('Accuracy:' + str(accuracy))
print('Number of misclassfied images = ' + str(len(MISCLASSIFIED)) + ' out of ' + str(total))
```
```
Accuracy:0.9797979797979798
Number of misclassfied images = 6 out of 297
```
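The test set has only 9 yellow images out of 297, so an overall accuracy of about 98% can hide weak performance on the rare class. Below is a sketch of a 3x3 confusion matrix built from pairs of one-hot labels, such as the true labels in STANDARDIZED_TEST_LIST and the predictions from estimate_label. `one_hot_confusion` and `per_class_recall` are hypothetical helpers written with NumPy, not library functions.

```python
import numpy as np

CLASS_NAMES = ['red', 'yellow', 'green']  # one-hot order used throughout

def one_hot_confusion(true_labels, predicted_labels):
    """Rows: true class, columns: predicted class, both in CLASS_NAMES order."""
    matrix = np.zeros((3, 3), dtype=int)
    for t, p in zip(true_labels, predicted_labels):
        matrix[int(np.argmax(t)), int(np.argmax(p))] += 1
    return matrix

def per_class_recall(matrix):
    """Fraction of each true class classified correctly (nan if the class is absent)."""
    totals = matrix.sum(axis=1)
    with np.errstate(invalid='ignore', divide='ignore'):
        return np.diag(matrix) / totals

# Hypothetical usage with the test set:
# true = [label for _, label in STANDARDIZED_TEST_LIST]
# pred = [estimate_label(img) for img, _ in STANDARDIZED_TEST_LIST]
# cm = one_hot_confusion(true, pred)
# print(cm)
# print(dict(zip(CLASS_NAMES, per_class_recall(cm))))
```

Entry `cm[0, 2]` counts red lights predicted as green, which is the case test_red_as_green warns about.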
posted @ 2019-08-07 20:07 core!