import tensorflow as tf
# Toy inputs: 4 rows x 3 columns; each y1 entry is the matching x1 entry minus 0.5.
x1 = tf.constant([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
    [10.0, 11.0, 12.0],
])
y1 = tf.constant([
    [0.5, 1.5, 2.5],
    [3.5, 4.5, 5.5],
    [6.5, 7.5, 8.5],
    [9.5, 10.5, 11.5],
])

# Create the dataset: each element is an (x1 row, y1 row) pair sliced along axis 0.
dataset = tf.data.Dataset.from_tensor_slices((x1, y1))
# Shuffle (buffer of 100 exceeds the 4 rows, so the whole dataset is shuffled),
# batch into groups of 3, then repeat indefinitely. Because batch() is applied
# before repeat(), every pass over the 4 rows yields a batch of 3 followed by a
# partial batch of 1.
dataset = dataset.shuffle(100).batch(3).repeat()

# An initializable iterator must be initialized explicitly (iterator.initializer)
# before get_next() can be evaluated. The alternative,
# dataset.make_one_shot_iterator(), needs no initialization but its data source
# cannot be changed afterwards.
# NOTE(review): this is TF 1.x graph-mode API; under TF 2 the equivalent is
# tf.compat.v1.data.make_initializable_iterator(dataset) with eager execution
# disabled — confirm the target TF version.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(4):
        # Evaluate get_next() exactly once per step so no batch is skipped;
        # value is an (x_batch, y_batch) pair of arrays.
        value = sess.run(next_element)
        print(value)