# TensorFlow读取数据集生成batch——tf.data.Dataset.from_tensor_slices

import tensorflow as tf
# Demo: build a tf.data pipeline from in-memory tensors and pull batches
# from it through an initializable iterator (TF 1.x graph/session API).

# Two parallel 4x3 float tensors; row i of x1 is paired with row i of y1.
x1 = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]])
y1 = tf.constant([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5], [6.5, 7.5, 8.5], [9.5, 10.5, 11.5]])

# Create the dataset: slicing along the first axis yields one (x, y) row
# pair per element.
dataset = tf.data.Dataset.from_tensor_slices((x1, y1))
# Shuffle, group into batches of up to 3 rows, then repeat. Because batch()
# is applied before repeat(), each pass over the 4 rows produces a batch of
# 3 followed by a remainder batch of 1.
dataset = dataset.shuffle(100).batch(3).repeat()
# NOTE: dataset.make_one_shot_iterator() is the alternative; it needs no
# explicit initialization but its data source cannot be changed afterwards.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # An initializable iterator must be initialized before the first fetch.
    sess.run(iterator.initializer)
    for _ in range(4):
        # Every sess.run(next_element) advances the iterator by one batch,
        # so fetch exactly once per step; an extra run() here would
        # silently consume and discard a batch.
        value = sess.run(next_element)
        print(value)

 

posted @ 2020-04-05 15:58  致若千尘  阅读(4117)  评论(0)    收藏  举报