Chapter 4 (Part 2): The Most Basic Classification Algorithm - k-Nearest Neighbors (kNN)
4-6 Grid Search and More Hyperparameters for k-Nearest Neighbors
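Besides n_neighbors, the grid defined in the notebook below searches over two further hyperparameters of KNeighborsClassifier: weights, which is either 'uniform' (every neighbor votes equally) or 'distance' (neighbors are weighted by the inverse of their distance), and p, the exponent of the Minkowski distance. As a standard reference formula (stated here for convenience, not taken from the original notes), the Minkowski distance between two samples $x$ and $y$ with $n$ features is

$$d_p(x, y) = \Big(\sum_{i=1}^{n} |x_i - y_i|^p\Big)^{1/p},$$

so p = 1 gives the Manhattan distance and p = 2 the Euclidean distance.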

Notebook example

Notebook source code
[1]
import numpy as np
from sklearn import datasets

[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target

[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=111)

[4]
from sklearn.neighbors import KNeighborsClassifier

knn_clf = KNeighborsClassifier(n_neighbors=6)
knn_clf.fit(X_train, y_train)
knn_clf.score(X_test, y_test)

0.9833333333333333

Grid Search

[5]
param_grid = [
    {
        'weights': ['unifrom'],   # typo: should be 'uniform'; this is what triggers the FitFailedWarning below
        'n_neighbors': [i for i in range(1, 11)]
    },
    {
        'weights': ['distance'],
        'n_neighbors': [i for i in range(1, 11)],
        'p': [i for i in range(1, 6)]
    }
]

[6]
knn_clf = KNeighborsClassifier()

[7]
from sklearn.model_selection import GridSearchCV

grid_search = GridSearchCV(knn_clf, param_grid)

[8]
%%time
grid_search.fit(X_train, y_train)

CPU times: total: 2min 15s
Wall time: 2min 18s

FitFailedWarning: 50 fits failed out of a total of 300; the scores on those train-test partitions are set to nan. The underlying error is
ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function
(the 10 candidates using the misspelled 'unifrom' each fail on all 5 folds).
UserWarning: One or more of the test scores are non-finite: the first 10 entries of the score array are nan; the remaining 50 candidates score between roughly 0.977 and 0.992.

GridSearchCV(estimator=KNeighborsClassifier(),
             param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'weights': ['unifrom']},
                         {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'p': [1, 2, 3, 4, 5], 'weights': ['distance']}])

[9]
grid_search.best_estimator_

KNeighborsClassifier(n_neighbors=1, p=4, weights='distance')

[10]
grid_search.best_score_

0.9920445203313729

[11]
grid_search.best_params_

{'n_neighbors': 1, 'p': 4, 'weights': 'distance'}

With random_state=111, grid_search.best_score_ = 0.9920445203313729 and grid_search.best_params_ = {'n_neighbors': 1, 'p': 4, 'weights': 'distance'}.

[12]
knn_clf = grid_search.best_estimator_

[13]
knn_clf.predict(X_test)

array([7, 1, 2, 9, 5, 8, 6, 4, 8, 3, 8, 4, 4, 5, 7, 1, 6, 1, 0, 6, 6, 8,
       ... (predicted labels for all 540 test samples, output truncated)])

[14]
knn_clf.score(X_test, y_test)

0.9907407407407407

[15]
%%time
grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=4, verbose=2)
grid_search.fit(X_train, y_train)
# Fitting the candidate classifiers can be parallelized: n_jobs is the number of CPU cores to use
# (default 1, a single core; -1 uses all cores).
# verbose prints progress information while the search runs; the larger the value, the more detail.

Fitting 5 folds for each of 60 candidates, totalling 300 fits
CPU times: total: 484 ms
Wall time: 1min 28s

(The same FitFailedWarning and non-finite-score UserWarning as in cell [8] appear again, caused by the same 'unifrom' typo.)

GridSearchCV(estimator=KNeighborsClassifier(n_neighbors=1, p=4,
                                            weights='distance'),
             n_jobs=4,
             param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'weights': ['unifrom']},
                         {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
             verbose=2)
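For reference, a compact, cleaned-up version of the search above, with the 'unifrom' typo corrected so that no fits fail. This is a sketch rather than the course's original code; the best parameters and exact scores may differ slightly across scikit-learn versions.

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Load the handwritten-digits dataset and hold out 30% as a test set.
digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.3, random_state=111)

# Same search space as in the notebook, with 'uniform' spelled correctly.
param_grid = [
    {'weights': ['uniform'], 'n_neighbors': list(range(1, 11))},
    {'weights': ['distance'], 'n_neighbors': list(range(1, 11)), 'p': list(range(1, 6))},
]

grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, n_jobs=-1, verbose=1)
grid_search.fit(X_train, y_train)

best_knn = grid_search.best_estimator_             # best classifier, refit on all of X_train
print(grid_search.best_params_, grid_search.best_score_)
print(best_knn.score(X_test, y_test))              # generalization score on the held-out data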
4-7 Data Normalization
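The notebook below implements both normalization schemes by hand with NumPy. As standard reference formulas (stated here for convenience; they match what the notebook computes): min-max normalization maps each feature into $[0, 1]$ via

$$x_{\text{scale}} = \frac{x - x_{\min}}{x_{\max} - x_{\min}},$$

while mean-variance normalization (standardization) rescales each feature to zero mean and unit variance via

$$x_{\text{scale}} = \frac{x - \mu}{\sigma},$$

where $\mu$ and $\sigma$ are the feature's mean and standard deviation.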




Notebook example

Notebook source code
Data normalization

[1]
import numpy as np
import matplotlib.pyplot as plt

Min-max normalization (Normalization)

[2]
x = np.random.randint(0, 100, size=100)

[3]
x

array([ 2, 58, 55, 40, 68,  7, 72,  4, 50, 89, 96, 19, 71,  6, 41, 40, 63,
        4,  2, 26, 79, 62, 62, 92, 49, 46, 75, 47, 44, 91, 40, 67, 38, 14,
       ... (100 random integers in [0, 100), output truncated)])

[5]
(x - np.min(x)) / (np.max(x) - np.min(x))

array([0.02040816, 0.59183673, 0.56122449, 0.40816327, 0.69387755,
       0.07142857, 0.73469388, 0.04081633, 0.51020408, 0.90816327,
       ... (the same 100 values rescaled into [0, 1], output truncated)])

[6]
X = np.random.randint(0, 100, (50, 2))

[7]
X[:10, :]

array([[19, 14],
       [23, 82],
       [ 4, 17],
       [44, 58],
       [23, 91],
       [46, 17],
       [34, 25],
       [29, 39],
       [69, 61],
       [70, 25]])

[9]
X = np.array(X, dtype=float)

[10]
X[:10, :]

array([[19., 14.],
       [23., 82.],
       [ 4., 17.],
       [44., 58.],
       [23., 91.],
       [46., 17.],
       [34., 25.],
       [29., 39.],
       [69., 61.],
       [70., 25.]])

[18]
X[:, 0] = (X[:, 0] - np.min(X[:, 0])) / (np.max(X[:, 0]) - np.min(X[:, 0]))

[19]
X[:, 1] = (X[:, 1] - np.min(X[:, 1])) / (np.max(X[:, 1]) - np.min(X[:, 1]))

[20]
X[:10, :]

array([[0.19191919, 0.11458333],
       [0.23232323, 0.82291667],
       [0.04040404, 0.14583333],
       [0.44444444, 0.57291667],
       [0.23232323, 0.91666667],
       [0.46464646, 0.14583333],
       [0.34343434, 0.22916667],
       [0.29292929, 0.375     ],
       [0.6969697 , 0.60416667],
       [0.70707071, 0.22916667]])

[21]
plt.scatter(X[:, 0], X[:, 1])

<matplotlib.collections.PathCollection at 0x1f514d5c2e0>
(scatter plot of the min-max normalized points; figure not reproduced here)

[22]
np.mean(X[:, 0])

0.4503030303030303

[26]
np.std(X[:, 0])

0.32224653972392703

[24]
np.mean(X[:, 1])

0.4503030303030303

[27]
np.std(X[:, 1])

0.3004160887650475

Mean-variance normalization (standardization)

[28]
X2 = np.random.randint(0, 100, (50, 2))

[29]
X2 = np.array(X2, dtype=float)

[36]
X2[:, 0] = (X2[:, 0] - np.mean(X2[:, 0])) / np.std(X2[:, 0])

[37]
X2[:, 1] = (X2[:, 1] - np.mean(X2[:, 1])) / np.std(X2[:, 1])

[38]
plt.scatter(X2[:, 0], X2[:, 1])

<matplotlib.collections.PathCollection at 0x1f517fa1430>
(scatter plot of the standardized points; figure not reproduced here)

[39]
np.mean(X2[:, 0])

0.0

[40]
np.std(X2[:, 0])

1.0

[41]
np.mean(X2[:, 1])

-4.4408920985006264e-17

[42]
np.std(X2[:, 1])

1.0
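As a small aside, the two schemes from this notebook can be wrapped into reusable helpers. This is a minimal sketch with function names of my own choosing (not from the course):

import numpy as np

def min_max_normalize(X):
    """Scale every column of X into [0, 1], column by column."""
    X = np.asarray(X, dtype=float)
    return (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0))

def mean_std_normalize(X):
    """Standardize every column of X to zero mean and unit variance."""
    X = np.asarray(X, dtype=float)
    return (X - X.mean(axis=0)) / X.std(axis=0)

X = np.random.randint(0, 100, (50, 2))
X_minmax = min_max_normalize(X)        # all values fall in [0, 1]
X_standard = mean_std_normalize(X)     # each column has mean ~0 and std ~1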
4-8 The Scaler in scikit-learn




Notebook example

Notebook source code
[1]
import numpy as np
from sklearn import datasets

[2]
iris = datasets.load_iris()

[3]
X = iris.data
y = iris.target

[4]
X[:10, :]

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1]])

[5]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=666)

The StandardScaler in scikit-learn

[6]
from sklearn.preprocessing import StandardScaler

[7]
standardScaler = StandardScaler()

[8]
standardScaler.fit(X_train)

StandardScaler()

[9]
standardScaler.mean_

array([5.81619048, 3.08761905, 3.66952381, 1.15714286])

[10]
standardScaler.scale_  # the old std_ attribute is deprecated; scale_ describes the spread of each feature

array([0.80747977, 0.43789436, 1.76166176, 0.75464998])

[11]
standardScaler.transform(X_train)

array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
       [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
       [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
       ... (105 standardized training samples, output truncated)])

[12]
X_train

array([[5.3, 3.7, 1.5, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [4.4, 2.9, 1.4, 0.2],
       ... (the 105 original, unscaled training samples, output truncated)])

[13]
X_train = standardScaler.transform(X_train)

[14]
X_train

array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
       [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
       [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
       ... (same standardized values as the output of cell [11], truncated)])

[15]
X_test_standard = standardScaler.transform(X_test)

[16]
X_test_standard

array([[-0.26773485, -0.20009175,  0.47141637,  0.45432605],
       [-0.02005063, -0.6568229 ,  0.81200388,  1.64693191],
       [-1.01078752, -1.7986508 , -0.20975866, -0.20823277],
       ... (45 standardized test samples, output truncated)])

[17]
from sklearn.neighbors import KNeighborsClassifier

[18]
knn_clf = KNeighborsClassifier(n_neighbors=3)

[19]
knn_clf.fit(X_train, y_train)

KNeighborsClassifier(n_neighbors=3)

[20]
knn_clf.score(X_test_standard, y_test)

0.9777777777777777

[21]
knn_clf.score(X_test, y_test)  # scoring against the unscaled test data: accuracy collapses

0.3333333333333333
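The last two cells show why the scaler fitted on the training data must also be applied to the test data: scoring the unscaled X_test drops accuracy from about 0.98 to 0.33. As an optional refinement not used in the original notebook, scikit-learn's Pipeline keeps the scaler and the classifier in sync automatically; a minimal sketch:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=666)

# The pipeline fits the scaler on X_train only and reuses its
# mean_/scale_ when transforming any data passed to predict/score.
knn_pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('knn', KNeighborsClassifier(n_neighbors=3)),
])
knn_pipeline.fit(X_train, y_train)
print(knn_pipeline.score(X_test, y_test))  # test data is scaled internally before scoring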
4-9 More Thoughts on the k-Nearest Neighbors Algorithm





