火磷
Memory will fade,but not notes.
Splitting Data into Batches in TensorFlow: A Tutorial with Examples

1. Introduction

To split data into batches, you can use the functions in tf.train or the methods of tf.data.Dataset.

1.1 tf.train

tf.train.slice_input_producer(tensor_list,shuffle=True,seed=None,capacity=32)

tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,allow_smaller_final_batch=False)

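Parameters:

shuffle: when True, the slices are randomly shuffled within each epoch.

allow_smaller_final_batch: when True, a final batch smaller than batch_size is returned if the queue is closed and too few elements remain to fill a batch; when False, those leftover elements are discarded.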
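
The effect of allow_smaller_final_batch can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow's implementation: the helper name batch_finite is hypothetical, and a finite list stands in for a queue that has been closed.

```python
def batch_finite(items, batch_size, allow_smaller_final_batch=False):
    """Group a finite stream into batches (hypothetical helper, not a TF API)."""
    batches = []
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # The input is exhausted here, which plays the role of a closed queue.
    # Leftover elements are kept only if the flag allows a short final batch.
    if current and allow_smaller_final_batch:
        batches.append(current)
    return batches


labels = list(range(15))
full_only = batch_finite(labels, 6)
with_tail = batch_finite(labels, 6, allow_smaller_final_batch=True)
print(full_only)   # two full batches; 12, 13, 14 are discarded
print(with_tail)   # the same two batches plus the short batch [12, 13, 14]
```

With 15 elements and a batch size of 6, the flag decides whether the last three elements are emitted as a short batch or dropped.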


1.2 tf.data.Dataset

tf.data.Dataset is a class. The following methods are used in this tutorial:

from_tensor_slices(tensors)

batch(batch_size,drop_remainder=False)

shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

repeat(count=None)

make_one_shot_iterator() / get_next()

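Note: make_one_shot_iterator() creates an iterator over the Dataset, and the iterator's get_next() returns its next element.

Parameters:

tensors: may be a list, a dict, a tuple, and so on.

drop_remainder: when False (the default), a final batch smaller than batch_size is kept; when True, it is dropped.

buffer_size: the size of the buffer used for shuffling.

count: the number of times the dataset is repeated, i.e. the number of epochs; None repeats the data indefinitely.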
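
The order in which batch and repeat are chained changes where batch boundaries fall. The sketch below emulates the two methods with plain Python generators to show this; the function names batch and repeat mirror the Dataset methods, but this is an assumed simplification, not the tf.data implementation, and the make_stream callable is a device of the sketch.

```python
import itertools


def batch(stream, batch_size, drop_remainder=False):
    """Yield lists of batch_size items; optionally drop a short final list."""
    current = []
    for item in stream:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current and not drop_remainder:
        yield current


def repeat(make_stream, count=None):
    """Replay a stream count times; count=None repeats forever.

    make_stream is a zero-argument callable returning a fresh iterable.
    """
    epochs = itertools.count() if count is None else range(count)
    for _ in epochs:
        yield from make_stream()


data = list(range(15))

# batch first, then repeat: every epoch ends with the short batch [12, 13, 14].
batch_then_repeat = list(repeat(lambda: batch(data, 6), count=2))

# repeat first, then batch: batches run across the epoch boundary.
repeat_then_batch = list(batch(repeat(lambda: data, count=2), 6))

# drop_remainder=True removes the short batch altogether.
dropped = list(batch(data, 6, drop_remainder=True))

print(batch_then_repeat)
print(repeat_then_batch)
print(dropped)
```

Batching before repeating keeps epochs separate at the cost of one short batch per epoch, while repeating before batching produces only full batches until the data finally runs out.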

2. Examples

2.1 Using tf.train.slice_input_producer and tf.train.batch

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, queue-based input)
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print the shapes of the sample data
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of examples,
# and the number of iterations needed for one pass over the data
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# tf.train.slice_input_producer puts slices of the data into a queue;
# tf.train.batch groups the queued slices into batches
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))


with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue runner threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    # Fetch and print each batch
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch))
    # Ask the threads to stop and wait for them to finish
    coord.request_stop()
    coord.join(threads)


# Output:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14  0  1  2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14  0  1  2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
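
The batches above run across epoch boundaries (for example [12 13 14 0 1 2]) because, with shuffle=False, the producer keeps cycling through the 15 labels in order and every batch simply takes the next 6 of them. A minimal sketch of that arithmetic follows; labels_at_step is a hypothetical helper written for this illustration and assumes the unshuffled, endlessly repeating case.

```python
def labels_at_step(step, batch_size=6, num_examples=15):
    """Labels of the step-th batch (0-based) drawn from an endless in-order cycle."""
    start = step * batch_size
    return [(start + offset) % num_examples for offset in range(batch_size)]


iterations = 3  # ceil(15 / 6), as in the example above
for step in range(9):
    epoch, iteration = divmod(step, iterations)
    print(epoch + 1, iteration + 1, labels_at_step(step))
```

Nine batches of six consume 54 labels, which is 3.6 passes over the 15 examples, so the loop's notion of an epoch does not line up with full passes over the data.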

With tf.train.slice_input_producer(shuffle=True), the output is shuffled. The result is as follows:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2  5  8 11  3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12  7  1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13  5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13  1  4 12  3]
The 3 epoch, the 1 iteration, current batch is [ 2  8  5  9 14  7]
The 3 epoch, the 2 iteration, current batch is [ 0 11  6  1 14  9]
The 3 epoch, the 3 iteration, current batch is [11  6 12  7  0 13]

With tf.train.batch(allow_smaller_final_batch=True), batches containing fewer than batch_size elements are returned as well. The result is as follows:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]

2.2 Using the tf.data.Dataset class

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels
# Print the shapes of the sample data
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of examples, and the iterations per pass
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# from_tensor_slices builds a dataset of (image, label) slices; batch groups them and repeat() continues indefinitely
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()

# Create an iterator with make_one_shot_iterator and fetch elements with get_next
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run(next_iterator)
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch))


# Output:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]

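To shuffle the data, change line 23 of the listing (the dataset.batch(batch_size).repeat() statement) to dataset = dataset.shuffle(100).batch(batch_size).repeat(). The result is as follows: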

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7  4 10  8  3 11]
The 1 epoch, the 2 iteration, current batch is [ 0  2 12 13 14  5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14  7  9  3  8]
The 2 epoch, the 2 iteration, current batch is [13  5 12  1 11  2]
The 2 epoch, the 3 iteration, current batch is [ 0  4 10]
The 3 epoch, the 1 iteration, current batch is [10  8 13 12  3 14]
The 3 epoch, the 2 iteration, current batch is [ 6  9  2  5  1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
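
The buffer_size argument controls how thoroughly the data is mixed. Here the buffer (100) is larger than the dataset (15), so each epoch is a full random permutation. The sketch below shows the general idea of a shuffle buffer in pure Python; it is an assumed illustration of the technique rather than TensorFlow's implementation, and shuffle_buffer is a hypothetical name.

```python
import random


def shuffle_buffer(stream, buffer_size, seed=None):
    """Yield items in a randomized order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # The input is exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


labels = list(range(15))
print(list(shuffle_buffer(labels, 1)))            # a buffer of 1 keeps the original order
print(list(shuffle_buffer(labels, 3, seed=0)))    # only local mixing
print(list(shuffle_buffer(labels, 100, seed=0)))  # buffer >= dataset size: full shuffle
```

A small buffer can only move an element a few positions forward, which is why a buffer at least as large as the dataset is needed for a uniform shuffle.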


Posted on 2018-11-13 10:22 by 火磷