Tensoflow常用的DataReaders

TensorFlow中有三种方法可以用来读取数据,第一种是直接使用constant给常量赋值;第二种是使用dict进行feed,但是速度会很慢,因为需要先把文件整个读到内存中。

第三种也是最常用的是使用TF中自带的datareaders来进行:TF 包含了四种DataReaders:
tf.TextLineReader
最常用的,可以按行读取文件。
tf.FixedLengthRecordReader
可以输出所有整个文件的所有数据(要求所有数据长度,格式相同)
tf.WholeFIleReader
直接输出整个文件(一般用于一个文件就是一组数据的情况)
tf.TFRecordReader
读取TF自身的二进制格式文件(TFRecord)
tf.ReaderBase
可以作为基类来定制自己的Reader.

要使用data reader,首先建立一个包含所有文件名的队列:

filename_queue = tf.train.string_input_producer(["heart.csv"])
reader = tf.TextLineReader(skip_header_lines=1)

可以认为reader是类似于生成器一样的东西,每次调用都会产生不同的键值对:

Key, value = reader.read(filename_queue)

Tf.train.string_input_producer 建立一个先进先出队列,这里需要用到tf.Coordinator 和 tf.QueueRunner。

with tf.Session() as sess:
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    print(sess.run(key)) # data/heart.csv:2 
    print(sess.run(value)) # 144,0.01,4.41,28.61,Absent,55,28.87,2.06,63,1 coord.request_stop()     
    coord.join(threads) 

这里coord是用来协调线程的。

然后还可以用csv解析器来对读取的值进行处理,比如144,0.01,4.41,28.61,Absent,55,28.87,2.06,63,1中前九个是value,最后一个1是class。
Content = tf.decode_csv(value, record_defaults=record_defaults)
解析器的两个参数分别指定了每一列的值的类型和每个值的默认类型,如果有数据的某个值是空的,它会按照我们设定的方式进行填充,填充方法如下:

record_defaults = [[1.0] for _ in range(N_FEATURES)] # define all features to be floats
record_defaults[4] = [''] # make the fifth feature string record_defaults.append([1]) content = tf.decode_csv(value, record_defaults=record_defaults)

  

posted @ 2017-09-04 11:31  dabney  阅读(294)  评论(0编辑  收藏  举报