初识机器学习-K近邻算法

Posted on 2017-05-05 00:17  Y_Mathison  阅读(285)  评论(1)    收藏  举报
 1 '''
 2 Created on Oct 6, 2010
 3 
 4 @author: Peter
 5 '''
 6 from numpy import *
 7 import matplotlib
 8 import matplotlib.pyplot as plt
 9 from matplotlib.patches import Rectangle
10 
11 
12 n = 1000 #number of points to create
13 xcord = zeros((n))
14 ycord = zeros((n))
15 markers =[]
16 colors =[]
17 fw = open('testSet.txt','w')
18 for i in range(n):
19     [r0,r1] = random.standard_normal(2)
20     myClass = random.uniform(0,1)
21     if (myClass <= 0.16):
22         fFlyer = random.uniform(22000, 60000)
23         tats = 3 + 1.6*r1
24         markers.append(20)
25         colors.append(2.1)
26         classLabel = 1 #'didntLike'
27         print ("%d, %f, class1") % (fFlyer, tats)
28     elif ((myClass > 0.16) and (myClass <= 0.33)):
29         fFlyer = 6000*r0 + 70000
30         tats = 10 + 3*r1 + 2*r0
31         markers.append(20)
32         colors.append(1.1)
33         classLabel = 1 #'didntLike'
34         print ("%d, %f, class1") % (fFlyer, tats)
35     elif ((myClass > 0.33) and (myClass <= 0.66)):
36         fFlyer = 5000*r0 + 10000
37         tats = 3 + 2.8*r1
38         markers.append(30)
39         colors.append(1.1)
40         classLabel = 2 #'smallDoses'
41         print ("%d, %f, class2") % (fFlyer, tats)
42     else:
43         fFlyer = 10000*r0 + 35000
44         tats = 10 + 2.0*r1
45         markers.append(50)
46         colors.append(0.1)
47         classLabel = 3 #'largeDoses'
48         print ("%d, %f, class3") % (fFlyer, tats)
49     if (tats < 0): tats =0
50     if (fFlyer < 0): fFlyer =0
51     xcord[i] = fFlyer; ycord[i]=tats
52     fw.write("%d\t%f\t%f\t%d\n" % (fFlyer, tats, random.uniform(0.0, 1.7), classLabel))
53 
54 fw.close()
55 fig = plt.figure()
56 ax = fig.add_subplot(111)
57 ax.scatter(xcord,ycord, c=colors, s=markers)
58 type1 = ax.scatter([-10], [-10], s=20, c='red')
59 type2 = ax.scatter([-10], [-15], s=30, c='green')
60 type3 = ax.scatter([-10], [-20], s=50, c='blue')
61 ax.legend([type1, type2, type3], ["Class 1", "Class 2", "Class 3"], loc=2)
62 #ax.axis([-5000,100000,-2,25])
63 plt.xlabel('Frequent Flyier Miles Earned Per Year')
64 plt.ylabel('Percentage of Body Covered By Tatoos')
65 plt.show()
运行效果: