Python数据预处理之打乱数据集
import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/',one_hot=True) train_imgs = mnist.train.images train_labels = mnist.train.labels test_imgs = mnist.test.images test_label_imgs = mnist.test.labels # 取训练数据的20% validate_datasets = 0.2 # 打乱的索引序列 permutation = np.random.permutation(train_labels.shape[0]) validate_indexs = permutation[:int(train_labels.shape[0]*validate_datasets)] train_indexs = permutation[int(train_labels.shape[0]*validate_datasets):] x_train_imgs = train_imgs[train_indexs,:] y_train_labels = train_labels[train_indexs,:] validate_imgs = train_imgs[validate_indexs,:] validate_labels = train_labels[validate_indexs,:]
作者:符号哥
微信公众号:左侧为二维码
个人技术网站-编程符号网:http://www.itfh.cn
个人技术网站-IT源码网:http://www.itym.cn
新浪微博:https://weibo.com/u/2814576687
如果你想及时得到个人撰写文章以及著作的消息推送,或者想看看个人推荐的技术资料,可以扫描左边二维码(或者长按识别二维码)关注个人公众号。
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

浙公网安备 33010602011771号