tf数据读取

# coding:utf-8
import os
if not os.path.exists('read'):
    os.makedirs('read/')

# 导入TensorFlow
import tensorflow as tf 

# 新建一个Session
with tf.Session() as sess:
    # 我们要读三幅图片A.jpg, B.jpg, C.jpg
    filename = ['A.jpg', 'B.jpg', 'C.jpg']
    # string_input_producer会产生一个文件名队列
    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
    # reader从文件名队列中读数据。对应的方法是reader.read
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    # tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 获取图片数据并保存
        image_data = sess.run(value)
        with open('read/test_%d.jpg' % i, 'wb') as f:
            f.write(image_data)
# 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志

1.tf.global_variables_initializer()和tf.local_variables_initializer()

tf.global_variables_initializer()添加节点用于初始化所有的变量,在构建完整个模型并在会话中加载模型后,运行这个节点。
能够将所有的变量一步到位的初始化。通过feed_dict, 你也可以将指定的列表传递给它,只初始化列表中的变量。
示例代码如下:

sess.run(tf.global_variables_initializer(), 
feed_dict={
        learning_rate_dis: learning_rate_val_dis,
        adam_beta1_d_tf: adam_beta1_d,
        learning_rate_proj: learning_rate_val_proj,
        lambda_ratio_tf: lambda_ratio,
        }) 

tf.local_variables_initializer()返回一个初始化所有局部变量的操作(Op)

sess.run(tf.local_variables_initializer(), 
feed_dict={
        learning_rate_dis: learning_rate_val_dis,
        adam_beta1_d_tf: adam_beta1_d,
        learning_rate_proj: learning_rate_val_proj,
        })

局部变量:在函数内定义的变量。
全局变量:在函数外定义的变量。

2 tf.FixedLengthRecordReader和tf.WholeFileReader函数

tf.WholeFileReader将文件的全部内容作为值输出的Reader。

如果要使用,请在队列(Queue)中的排列文件名。Read的输出将是一个文件名(key)和该文件的内容(value)。

 tf.FixedLengthRecordReader是读取固定长度字节数信息(针对bin文件使用FixedLengthRecordReader读取比较合适)

下次调用时会接着上次读取的位置继续读取文件,而不会从头开始读取。

3 tensorflow读取CIFAR-10数据

(1)用tf.train.string_input_producer建立队列

(2)用reader.read()读数据  tf.FixedLengthRecordReader和tf.WholeFileReader函数

(3)调用tf.train.start_queue_runners

(4)通过 sess.run()取图片结果

4 tensorflow读取数据3种方法

(1)用占位符(placeholder)读入

  (2)  用队列形式建立文件到tensor的映射

(3)用Dataset API读入数据

 



参考:21个项目玩转深度学习

https://blog.csdn.net/yyhhlancelot/article/details/81415137  

posted on 2018-12-28 11:00  yq~~  阅读(267)  评论(0)    收藏  举报