参考链接:http://blog.csdn.net/c406495762/article/details/75172850
import operator
from collections import Counter

import numpy as np


def creatDataset():
    """Build the four-point toy training set.

    :return: (group, labels) -- ``group`` is a (4, 2) float array of sample
        coordinates and ``labels`` lists the class name of each row.
    """
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def knn(inx, dataset, labels, k):
    """Classify ``inx`` by a majority vote of its ``k`` nearest neighbours.

    :param inx: feature vector to classify; must have as many entries as
        ``dataset`` has columns.
    :param dataset: 2-D array-like of training samples, one row per sample.
    :param labels: class label of each row of ``dataset``.
    :param k: number of neighbours that vote; values larger than the number
        of samples use every sample.
    :return: the label with the most votes. Ties go to the label that
        appears first among the neighbours ordered by distance.
    :raises ValueError: if ``k < 1`` or ``dataset`` has no rows.
    """
    dataset = np.asarray(dataset, dtype=float)
    if k < 1:
        raise ValueError("k must be at least 1")
    if dataset.shape[0] == 0:
        raise ValueError("dataset must contain at least one sample")
    # Broadcasting subtracts inx from every row, so np.tile is unnecessary.
    diff_mat = np.asarray(inx, dtype=float) - dataset
    distances = np.sqrt((diff_mat ** 2).sum(axis=1))
    # Stable sort: equal distances keep row order, making results repeatable.
    # Slicing also clamps k to the number of available samples.
    nearest = distances.argsort(kind='stable')[:k]
    votes = Counter(labels[i] for i in nearest)
    # sorted() is stable, so equal vote counts keep first-seen (nearest) order.
    ranked = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]


if __name__ == '__main__':
    # Build the training set
    group, labels = creatDataset()
    # Sample to classify
    test = [1.2, 1.8]
    # kNN classification
    test_class = knn(test, group, labels, 3)
    # Print the predicted class
    print(test_class)
示例二：根据三个特征（每年获得的飞行常客里程数、玩视频游戏所消耗时间占比、每周消费冰激凌公升数）判断对约会对象的喜好程度
import operator
from collections import Counter

import numpy as np

# Default location of the dating data set (path used by the original script).
DATING_DATA_PATH = r'C:\Users\TuZhiqiang\Desktop\MLIA\Ch02\datingTestSet.txt'
# Default font file used to render the Chinese plot titles and axis labels.
CHINESE_FONT_PATH = r"C:\WINDOWS\Fonts\simsun.ttc"
# Class label text found in the data file -> integer class id.
LABEL_IDS = {'didntLike': 1, 'smallDoses': 2, 'largeDoses': 3}
# Axis captions for feature columns 0, 1 and 2 of the data matrix.
FEATURE_NAMES = [u'每年获得的飞行常客里程数', u'玩视频游戏所消耗时间占比', u'每周消费冰激凌公升数']


def _parse_label(raw, line_no):
    """Convert one label field (text name or integer id) to a class id 1..3."""
    if raw in LABEL_IDS:
        return LABEL_IDS[raw]
    try:
        label = int(raw)
    except ValueError:
        label = None
    if label not in LABEL_IDS.values():
        raise ValueError("line %d: unknown class label %r" % (line_no, raw))
    return label


def file_matrix(filename=DATING_DATA_PATH):
    """Parse the tab-separated dating data file.

    Each non-blank line holds three numeric features followed by a class
    label, written either as text (``didntLike``/``smallDoses``/``largeDoses``)
    or directly as the integer id 1/2/3.

    :param filename: path of the data file.
    :return: (feature matrix of shape (n, 3), list of n integer labels).
    :raises ValueError: if a line carries an unrecognised label.
    """
    features = []
    class_lable_vector = []
    # `with` guarantees the file is closed even if parsing fails.
    with open(filename, encoding='utf-8') as fr:
        for line_no, line in enumerate(fr, start=1):
            fields = line.strip().split('\t')
            if fields == ['']:
                continue  # tolerate blank / trailing lines
            features.append([float(value) for value in fields[0:3]])
            # Failing loudly keeps labels aligned with matrix rows.
            class_lable_vector.append(_parse_label(fields[-1], line_no))
    return_matrix = np.array(features, dtype=float).reshape(-1, 3)
    return return_matrix, class_lable_vector


def showdatas(dating_data_matrix, dating_lables, font_path=CHINESE_FONT_PATH):
    """Scatter-plot pairs of features, coloured by class, and show the figure.

    :param dating_data_matrix: feature matrix as returned by ``file_matrix``.
    :param dating_lables: class id (1, 2 or 3) of each row.
    :param font_path: font file able to render the Chinese captions.
    :return: None; opens a matplotlib window with three scatter plots.
    :raises KeyError: if a label is not 1, 2 or 3.
    """
    # Imported here so the classifier functions work without matplotlib.
    import matplotlib.lines as mlines
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties

    # If the font file is unavailable, an alternative is
    # matplotlib.rcParams['font.sans-serif'] = ['SimHei'] together with
    # matplotlib.rcParams['axes.unicode_minus'] = False.
    font = FontProperties(fname=font_path, size=14)
    _, axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(13, 8))

    class_colors = {1: 'black', 2: 'orange', 3: 'red'}
    labels_colors = [class_colors[label] for label in dating_lables]

    # Legend entries, one per class, in file-label order.
    legend_handles = [
        mlines.Line2D([], [], color=class_colors[class_id], marker='.', markersize=6, label=name)
        for name, class_id in LABEL_IDS.items()
    ]

    # (axes, x column, y column); the bottom-right axes is left empty.
    panels = [(axs[0][0], 0, 1), (axs[0][1], 0, 2), (axs[1][0], 1, 2)]
    for ax, x_col, y_col in panels:
        ax.scatter(x=dating_data_matrix[:, x_col], y=dating_data_matrix[:, y_col],
                   color=labels_colors, s=15, alpha=0.5)
        title = ax.set_title(FEATURE_NAMES[x_col] + u'与' + FEATURE_NAMES[y_col], fontproperties=font)
        xlabel = ax.set_xlabel(FEATURE_NAMES[x_col], fontproperties=font)
        ylabel = ax.set_ylabel(FEATURE_NAMES[y_col], fontproperties=font)
        plt.setp(title, size=9, weight='bold', color='green')
        plt.setp(xlabel, size=7, weight='bold', color='black')
        plt.setp(ylabel, size=7, weight='bold', color='black')
        ax.legend(handles=legend_handles)

    plt.show()


def auto_norm(data_set):
    """Min-max scale every column of ``data_set`` into [0, 1].

    :param data_set: 2-D array-like, one row per sample.
    :return: (normalised copy, per-column divisor, per-column minimum).
        The divisor is ``max - min`` except for constant columns, which use 1
        so they map to 0 instead of NaN. Normalise new samples with
        ``(x - min_vals) / ranges``.
    """
    data_set = np.asarray(data_set, dtype=float)
    min_vals = data_set.min(0)
    max_vals = data_set.max(0)
    ranges = max_vals - min_vals
    ranges = np.where(ranges == 0, 1.0, ranges)  # avoid 0/0 on constant columns
    # Broadcasting applies the per-column min/range to every row.
    norm_data_set = (data_set - min_vals) / ranges
    return norm_data_set, ranges, min_vals


def classifier_calculator(inx, data_set, labels, k):
    """Classify ``inx`` by a majority vote of its ``k`` nearest neighbours.

    :param inx: feature vector with as many entries as ``data_set`` columns.
    :param data_set: 2-D array-like of training samples, one row per sample.
    :param labels: class label of each row of ``data_set``.
    :param k: number of voting neighbours; clamped to the number of samples.
    :return: the winning label; ties go to the label seen first among the
        neighbours ordered by distance.
    :raises ValueError: if ``k < 1`` or ``data_set`` has no rows.
    """
    data_set = np.asarray(data_set, dtype=float)
    if k < 1:
        raise ValueError("k must be at least 1")
    if data_set.shape[0] == 0:
        raise ValueError("data_set must contain at least one sample")
    diff_mat = np.asarray(inx, dtype=float) - data_set
    distances = np.sqrt((diff_mat ** 2).sum(axis=1))
    # Stable sort makes equal-distance ties repeatable; slicing clamps k.
    nearest = distances.argsort(kind='stable')[:k]
    class_count = Counter(labels[i] for i in nearest)
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]


def dating_class_test(ho_ratio=0.10, k=4, filename=DATING_DATA_PATH):
    """Hold-out test: classify the first ``ho_ratio`` of rows against the rest.

    Normalisation is computed over the whole file (test rows included), as in
    the original experiment. Prints every prediction and the error rate.

    :return: error rate as a fraction in [0, 1].
    :raises ValueError: if ``ho_ratio`` selects no test rows.
    """
    dating_data_matrix, dating_lables = file_matrix(filename)
    norm_data_set, _, _ = auto_norm(dating_data_matrix)
    n = norm_data_set.shape[0]
    num_test = int(n * ho_ratio)
    if num_test == 0:
        raise ValueError("ho_ratio selects no test samples")
    error_count = 0
    for i in range(num_test):
        classifier_result = classifier_calculator(
            norm_data_set[i, :], norm_data_set[num_test:n, :], dating_lables[num_test:n], k)
        print("分类结果:%s \t 真实类别:%s" % (classifier_result, dating_lables[i]))
        if classifier_result != dating_lables[i]:
            error_count += 1
    error_rate = error_count / float(num_test)
    print("错误率:%f%%" % (error_rate * 100))
    return error_rate


def _predict_preference(fly_miles, game_time, ice_cream, data_matrix, labels, k=5):
    """Return the class id predicted for one person's raw (unnormalised) features."""
    norm_data_set, ranges, min_vals = auto_norm(data_matrix)
    # Order must match the data matrix columns: miles, game-time %, ice cream.
    # (The original put game_time first, swapping the first two features.)
    in_arr = np.array([fly_miles, game_time, ice_cream])
    # The query must be normalised with the training set's min and range.
    norm_in_arr = (in_arr - min_vals) / ranges
    return classifier_calculator(norm_in_arr, norm_data_set, labels, k)


def classify_person(filename=DATING_DATA_PATH, k=5):
    """Prompt for one person's three features and print the predicted preference.

    :param filename: dating data file used as the training set.
    :param k: number of neighbours that vote.
    """
    result_list = ["讨厌", "有些喜欢", "非常喜欢"]
    game_time = float(input("玩视频游戏所耗时间百分比:"))
    fly_miles = float(input("每年获得的飞行常客里程数:"))
    ice_cream = float(input("每周消费的冰激淋公升数:"))
    dating_data_matrix, dating_lables = file_matrix(filename)
    classifier_result = _predict_preference(
        fly_miles, game_time, ice_cream, dating_data_matrix, dating_lables, k)
    # Class ids are 1-based; result_list is 0-based.
    print("你可能%s这个人" % (result_list[classifier_result - 1]))


if __name__ == '__main__':
    # Other entry points, kept for experimentation:
    #   dating_data_matrix, dating_lables = file_matrix()
    #   showdatas(dating_data_matrix, dating_lables)
    #   norm_data_set, ranges, min_vals = auto_norm(dating_data_matrix)
    #   dating_class_test()
    classify_person()
浙公网安备 33010602011771号