会飞的蝌蚪君

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

参考链接：http://blog.csdn.net/c406495762/article/details/75172850

 

 1 import numpy as np
 2 import operator
 3 
 4 
 5 def creatDataset():
 6     group = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
 7     labels = ['A','A','B','B']
 8     return group,labels
 9 
10 
def knn(inx, dataset, labels, k):
    """Classify ``inx`` by majority vote among its ``k`` nearest neighbours.

    Distances are Euclidean.

    :param inx: query point; one value per column of ``dataset``
    :param dataset: 2-D array-like of training samples, one row per sample
    :param labels: sequence of class labels, ``labels[i]`` belongs to row ``i``
    :param k: number of voting neighbours; values larger than the number of
        samples are clamped so every sample votes
    :return: the label with the most votes; on a tie, the label whose first
        vote came from the closest neighbour
    :raises ValueError: if ``dataset`` has no rows or ``k`` < 1
    """
    from collections import Counter

    dataset = np.asarray(dataset)
    n_samples = dataset.shape[0]
    if n_samples == 0:
        raise ValueError("dataset must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    # Asking for more neighbours than there are samples used to raise IndexError.
    k = min(k, n_samples)

    # Broadcasting subtracts the query from every row (same result as np.tile).
    diff = np.asarray(inx) - dataset
    distances = np.sqrt((diff ** 2).sum(axis=1))
    # Stable sort: equidistant samples keep their row order, so ties are reproducible.
    nearest = distances.argsort(kind='stable')[:k]

    votes = Counter(labels[i] for i in nearest)
    # most_common keeps first-encountered order among equal counts, i.e. the
    # label first seen when walking outwards from the query wins a tie.
    return votes.most_common(1)[0][0]
30 
31 
if __name__ == '__main__':
    # Toy training data: two 'A' points near (1, 1) and two 'B' points near the origin.
    group, labels = creatDataset()
    # Query point to classify.
    test = [1.2, 1.8]
    # Majority vote of the 3 nearest neighbours, then print the winning label.
    test_class = knn(test, group, labels, k=3)
    print(test_class)

 

 

示例二：根据三个特征（每年获得的飞行常客里程数、玩视频游戏所耗时间百分比、每周消费的冰激凌公升数）判断对约会对象的喜好程度

  1 import numpy as np
  2 import matplotlib.pyplot as plt
  3 from matplotlib.font_manager import FontProperties
  4 import matplotlib.lines as mlines
  5 from pylab import *
  6 import operator
  7 #mpl.rcParams['font.sans-serif'] = ['SimHei'] #指定默认字体
  8 #mpl.rcParams['axes.unicode_minus'] = False #解决保存图像是负号'-'显示为方块的问题
  9 
 10 def file_matrix():
 11     """
 12     处理数据格式问题
 13     :return: 特征矩阵,标签矩阵
 14     """
 15     fr = open(r'C:\Users\TuZhiqiang\Desktop\MLIA\Ch02\datingTestSet.txt')
 16     array_of_lines = fr.readlines()
 17     number_of_lines = len(array_of_lines)
 18     return_matrix = np.zeros((number_of_lines,3))
 19     class_lable_vector = []
 20     index = 0
 21     for line in array_of_lines:
 22         line = line.strip()
 23         list_from_line = line.split('\t')
 24         return_matrix[index,:] = list_from_line[0:3]
 25         if list_from_line[-1] == 'didntLike':
 26             class_lable_vector.append(1)
 27         elif list_from_line[-1] == 'smallDoses':
 28             class_lable_vector.append(2)
 29         elif list_from_line[-1] == 'largeDoses':
 30             class_lable_vector.append(3)
 31         index +=1
 32     #print(class_lable_vector)
 33     return return_matrix,class_lable_vector
 34 
 35 
 36 
 37 def showdatas(dating_data_matrix,dating_lables):
 38     """
 39     数据可视化
 40     :param dating_data_matrix: 特征矩阵
 41     :param dating_lables: 标签矩阵
 42     :return: 特征之间的关系图像
 43     """
 44     font = FontProperties(fname=r"C:\WINDOWS\Fonts\simsun.ttc", size=14)
 45     fig , axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(13, 8))  #画多图
 46     number_of_lables = len(dating_lables)
 47     labels_colors = []
 48     for i in dating_lables:
 49         if i == 1:
 50             labels_colors.append('black')
 51         if i == 2:
 52             labels_colors.append('orange')
 53         if i == 3:
 54             labels_colors.append('red')
 55     #print(labels_colors)
 56 
 57     # 每年获得的飞行常客里程数   与   玩视频游戏所消耗时间占比
 58     axs[0][0].scatter(x = dating_data_matrix[:,0], y=dating_data_matrix[:,1], color=labels_colors,s=15, alpha=0.5)
 59     axs0_title_text = axs[0][0].set_title(u'每年获得的飞行常客里程数与玩视频游戏所消耗时间占比',fontproperties=font)
 60     axs0_xlabel_text = axs[0][0].set_xlabel(u'每年获得的飞行常客里程数',fontproperties=font)
 61     axs0_ylabel_text = axs[0][0].set_ylabel(u'玩视频游戏所消耗时间占比',fontproperties=font)
 62     plt.setp(axs0_title_text, size=9, weight='bold', color='green')
 63     plt.setp(axs0_xlabel_text, size=7, weight='bold', color='black')
 64     plt.setp(axs0_ylabel_text, size=7, weight='bold', color='black')
 65 
 66     #每年获得的飞行常客里程数   与   每周消费冰激凌公升数
 67     axs[0][1].scatter(x = dating_data_matrix[:,0], y=dating_data_matrix[:,2], color=labels_colors,s=15, alpha=0.5)
 68     axs1_title_text = axs[0][1].set_title(u'每年获得的飞行常客里程数与每周消费冰激凌公升数',fontproperties=font)
 69     axs1_xlabel_text = axs[0][1].set_xlabel(u'每年获得的飞行常客里程数',fontproperties=font)
 70     axs1_ylabel_text = axs[0][1].set_ylabel(u'每周消费冰激凌公升数',fontproperties=font)
 71     plt.setp(axs1_title_text, size=9, weight='bold', color='green')
 72     plt.setp(axs1_xlabel_text, size=7, weight='bold', color='black')
 73     plt.setp(axs1_ylabel_text, size=7, weight='bold', color='black')
 74 
 75 
 76     #玩视频游戏所消耗时间占比    与   每周消费冰激凌公升数
 77     axs[1][0].scatter(x = dating_data_matrix[:,1], y=dating_data_matrix[:,2], color=labels_colors,s=15,alpha=0.5)
 78     axs2_title_text = axs[1][0].set_title(u'玩视频游戏所消耗时间占比与每周消费冰激凌公升数',fontproperties=font)
 79     axs2_xlabel_text = axs[1][0].set_xlabel(u'玩视频游戏所消耗时间占比',fontproperties=font)
 80     axs2_ylabel_text = axs[1][0].set_ylabel(u'每周消费冰激凌公升数',fontproperties=font)
 81     plt.setp(axs2_title_text, size=9, weight='bold', color='green')
 82     plt.setp(axs2_xlabel_text, size=7, weight='bold', color='black')
 83     plt.setp(axs2_ylabel_text, size=7, weight='bold', color='black')
 84 
 85 
 86     #设置图例
 87     didntLike = mlines.Line2D([], [], color='black', marker='.',markersize=6, label='didntLike')
 88     smallDoses = mlines.Line2D([], [], color='orange', marker='.',markersize=6, label='smallDoses')
 89     largeDoses = mlines.Line2D([], [], color='red', marker='.',markersize=6, label='largeDoses')
 90 
 91 
 92     #添加图例
 93     axs[0][0].legend(handles=[didntLike,smallDoses,largeDoses])
 94     axs[0][1].legend(handles=[didntLike,smallDoses,largeDoses])
 95     axs[1][0].legend(handles=[didntLike,smallDoses,largeDoses])
 96 
 97     plt.show()
 98 
 99 
100 
def auto_norm(data_set):
    """
    Min-max normalise every column of ``data_set`` into [0, 1].

    :param data_set: 2-D numeric array, one row per sample and one column per feature
    :return: ``(norm_data_set, ranges, min_vals)`` with
        ``norm_data_set = (data_set - min_vals) / ranges``. A column whose values
        are all equal has its range reported as 1 (not 0), so it normalises to 0
        instead of NaN. Callers can then reuse ``ranges``/``min_vals`` to
        normalise new samples the same way.
    """
    data_set = np.asarray(data_set)
    min_vals = data_set.min(axis=0)
    max_vals = data_set.max(axis=0)
    ranges = max_vals - min_vals
    # A constant column has range 0; dividing by it would give 0/0 = NaN.
    ranges = np.where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column min/range to every row (no np.tile needed).
    norm_data_set = (data_set - min_vals) / ranges
    return norm_data_set, ranges, min_vals
117 
118 
def classifier_calculator(inx, data_set, labels, k):
    """Return the majority label among the ``k`` training samples nearest to ``inx``.

    Distances are Euclidean.

    :param inx: query sample; one value per column of ``data_set``
    :param data_set: 2-D array of training samples, one row per sample
    :param labels: sequence of labels aligned with the rows of ``data_set``
    :param k: number of voting neighbours; clamped to the number of samples
    :return: winning label; on a tie, the label whose first vote came from the
        closest neighbour
    :raises ValueError: if ``data_set`` has no rows or ``k`` < 1
    """
    from collections import Counter

    data_set = np.asarray(data_set)
    data_set_size = data_set.shape[0]
    if data_set_size == 0:
        raise ValueError("data_set must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    # Asking for more neighbours than there are samples used to raise IndexError;
    # now every sample simply gets a vote.
    k = min(k, data_set_size)

    # Euclidean distance from inx to every row; broadcasting replaces np.tile.
    distances = np.sqrt(((np.asarray(inx) - data_set) ** 2).sum(axis=1))
    # A stable sort keeps equidistant rows in their original order, making ties reproducible.
    nearest_indices = distances.argsort(kind='stable')[:k]

    class_count = Counter(labels[i] for i in nearest_indices)
    # most_common keeps first-encountered order among equal counts, so the label
    # first seen when walking outwards from the query wins a tie.
    return class_count.most_common(1)[0][0]
138 
139 
140 
141 
142 
143 # def dataing_class_test():
144 #     dating_data_matrix, dating_lables = file_matrix()
145 #     abstruct_ratio = 0.10
146 #     norm_data_set, ranges, min_vals = auto_norm(dating_data_matrix)
147 #     n = norm_data_set.shape[0]
148 #     num_test_data_set = int(n * abstruct_ratio)
149 #     error_count = 0.0
150 #     for i in range(num_test_data_set):
151 #         classifier_result = classifier_calculator(norm_data_set[i,:],norm_data_set[num_test_data_set:n,:],dating_lables[num_test_data_set:n],4)
152 #         print("分类结果:%s \t 真实类别:%s"   % (classifier_result,dating_lables[i]))
153 #         if classifier_result != dating_lables[i]:
154 #             error_count += 1.0
155 #     print("错误率:%f%%" %(error_count/float(num_test_data_set)*100))
156 #
157 
158 
def classify_person():
    """
    Ask for a person's three features and print how much they are likely to be liked.

    The query is normalised with the training data's min/range and classified
    with the 5 nearest neighbours. Labels 1, 2 and 3 map to the three entries
    of ``result_list``.
    """
    result_list = ["讨厌", "有些喜欢", "非常喜欢"]
    game_time = float(input("玩视频游戏所耗时间百分比:"))
    fly_miles = float(input("每年获得的飞行常客里程数:"))
    ice_cream = float(input("每周消费的冰激淋公升数:"))
    dating_data_matrix, dating_lables = file_matrix()
    norm_data_set, ranges, min_vals = auto_norm(dating_data_matrix)
    # The query must follow the data file's column order
    # [flying miles, video-game time %, ice cream], not the order the prompts
    # are asked in. Previously the first two features were swapped.
    in_arr = np.array([fly_miles, game_time, ice_cream])
    # The query must be normalised exactly like the training set.
    norm_in_arr = (in_arr - min_vals) / ranges
    classifier_result = classifier_calculator(norm_in_arr, norm_data_set, dating_lables, 5)
    # Labels are 1-based; result_list is 0-based.
    print("你可能%s这个人" % (result_list[classifier_result - 1]))
171 
172 
173 
174 
175 
176 
177 
178 
179 
180 
181 
182 
183 
184 
185 
186 
187 
if __name__ == '__main__':
    # Interactive entry point: prompt for a person's three features and print
    # the predicted liking level.
    classify_person()

 

posted on 2018-03-05 22:34  会飞的蝌蚪  阅读(252)  评论(0)    收藏  举报