kNN dataset preparation

```python
"""
Created on Mon Aug 26 20:57:24 2019

@author: huoqs

knn algorithm
"""
import numpy as np
import matplotlib.pyplot as plt

def generate_data(num_samples, num_features=2):
    # Feature matrix: random integer coordinates in [0, 100)
    data_size = (num_samples, num_features)
    data = np.random.randint(0, 100, data_size)

    # Label column: one class ID (0 or 1) per sample
    label_size = (num_samples, 1)
    labels = np.random.randint(0, 2, label_size)
    # the training data must be float32
    return data.astype(np.float32), labels

def plot_data(all_blue, all_red):
    # Class 0 as blue squares, class 1 as red triangles
    plt.scatter(all_blue[:, 0], all_blue[:, 1], c='b', marker='s', s=180)
    plt.scatter(all_red[:, 0], all_red[:, 1], c='r', marker='^', s=180)
    plt.xlabel('x')
    plt.ylabel('y')

plt.style.use('ggplot')

# Fix the seed so the random data is reproducible
np.random.seed(42)

train_data, labels = generate_data(11)

# print(train_data, labels)

# labels has shape (11, 1); ravel() flattens it so it can act as a row mask
blue = train_data[labels.ravel() == 0]
red = train_data[labels.ravel() == 1]

plot_data(blue, red)
plt.show()
```
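The post stops at plotting the two classes. As an illustration of what this data is for, here is a minimal NumPy sketch of kNN classification on it: find the k training points closest to a new point and take a majority vote over their labels. This is an assumption-laden sketch, not the author's method; the `# must be float32` comment suggests the author intends to use OpenCV's kNN (`cv2.ml.KNearest_create`), which is not used here. `knn_predict` and the query point `newcomer` are hypothetical names.

```python
import numpy as np

def generate_data(num_samples, num_features=2):
    # Same generator as in the post: coordinates in [0, 100), labels 0 or 1
    data = np.random.randint(0, 100, (num_samples, num_features))
    labels = np.random.randint(0, 2, (num_samples, 1))
    return data.astype(np.float32), labels

def knn_predict(train_data, labels, query, k=3):
    """Hypothetical helper: majority vote among the k nearest training points."""
    # Euclidean distance from the query point to every training sample
    dists = np.linalg.norm(train_data - query, axis=1)
    nearest = np.argsort(dists)[:k]      # indices of the k closest samples
    votes = labels.ravel()[nearest]      # their class labels (0 or 1)
    # Count votes per class; argmax picks the class with the most votes
    return int(np.bincount(votes, minlength=2).argmax())

np.random.seed(42)
train_data, labels = generate_data(11)
newcomer = np.array([50.0, 50.0], dtype=np.float32)  # hypothetical query point
prediction = knn_predict(train_data, labels, newcomer, k=3)
print(prediction)
```

`argmax` returns the first maximum, so with an even k a tie goes to class 0; an odd k avoids ties when there are only two classes.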

Key points:

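1. np.random.randint: generates an array of random integers. Parameters: low, high, size, dtype. Values are drawn from the half-open interval [low, high); if high is omitted, they are drawn from [0, low).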

https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.random.randint.html

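2. ndarray.ravel(): flattens an array into a one-dimensional array, returning a view when possible. In the code above it turns the (11, 1) label column into a 1-D array so it can be used as a boolean mask that selects whole rows of train_data.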

https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ravel.html
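A small sketch tying the two points together, using made-up sizes (6 samples instead of the post's 11). It shows the half-open range of randint, the dtype parameter, and why the code flattens labels with ravel() before masking: a (6, 1) boolean mask does not match a (6, 2) array along axis 1, so NumPy rejects it. The variable names are illustrative only.

```python
import numpy as np

np.random.seed(42)

# randint(low, high, size): integers drawn from the half-open interval [low, high)
points = np.random.randint(0, 100, size=(6, 2)).astype(np.float32)
labels = np.random.randint(0, 2, size=(6, 1))              # (6, 1) column of 0/1
codes = np.random.randint(0, 10, size=4, dtype=np.int32)   # dtype sets the int type

# ravel() flattens the (6, 1) column to shape (6,); for this contiguous
# array it returns a view, so no data is copied
flat = labels.ravel()
print(labels.shape, flat.shape, np.shares_memory(labels, flat))

# A 1-D boolean mask of length 6 selects whole rows of the (6, 2) array
class0 = points[flat == 0]
class1 = points[flat == 1]
print(len(class0) + len(class1))

# The un-flattened (6, 1) mask does not match the (6, 2) array along
# axis 1, so NumPy raises IndexError instead of selecting rows
try:
    points[labels == 0]
    mask_error = False
except IndexError:
    mask_error = True
print(mask_error)
```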

Posted @ 2019-08-26 21:27 霍霍